Hi,
I am trying to build a recommender system using a GNN. As usual for ML tasks, I want to split my dataset into train, validation and test sets.
Following the example of the WWW20 Hands-on Tutorial here, I am considering splitting based on the edge IDs.
eids = np.random.permutation(g.number_of_edges())
train_eids = eids[:int(len(eids) * 0.8)]
valid_eids = eids[int(len(eids) * 0.8):int(len(eids) * 0.9)]
test_eids = eids[int(len(eids) * 0.9):]
train_g = g.edge_subgraph(train_eids, preserve_nodes=True)
Then, when I train my model, it would see only train_g:
...
loss = model(train_g, features, neg_sample_size)
...
And when I report performance, it would be on the full graph g, scoring only the edges in test_eids:
acc = LPEvaluate(gconv_model, g, features, test_eids, neg_sample_size)
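To make the two settings concrete, here is a minimal sketch in plain Python (no DGL). The toy edge list and the helper name `split_edge_ids` are made up for illustration. It only counts how many test edges are present in the graph the model would aggregate over, for the full graph g versus train_g.

```python
import random

def split_edge_ids(num_edges, seed=0):
    """Shuffle edge ids and cut them 80/10/10 (hypothetical helper)."""
    eids = list(range(num_edges))
    random.Random(seed).shuffle(eids)
    n_train = int(num_edges * 0.8)
    n_valid_end = int(num_edges * 0.9)
    return eids[:n_train], eids[n_train:n_valid_end], eids[n_valid_end:]

# Toy user-item edge list: edge id i connects edges[i][0] -> edges[i][1].
edges = [(u, 100 + (u * 7 + k) % 13) for u in range(10) for k in range(2)]
train_eids, valid_eids, test_eids = split_edge_ids(len(edges))

# Edges available to message passing in each setting.
visible_in_full_g = set(edges)                       # evaluating on g
visible_in_train_g = {edges[i] for i in train_eids}  # evaluating on train_g

test_edges = {edges[i] for i in test_eids}
leaked_full = test_edges & visible_in_full_g
leaked_train = test_edges & visible_in_train_g

print(len(leaked_full), "of", len(test_edges), "test edges are present in g")
print(len(leaked_train), "of", len(test_edges), "test edges are present in train_g")
```

In this toy setup every test edge is part of g, while none is part of train_g, which is the situation the two questions below are about.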
I have two questions regarding this architecture.
- The ‘training signal’ in link prediction is the connectivity between two nodes, i.e. nodes connected by edges are similar, while nodes not connected by edges are dissimilar. However, when splitting the dataset in the way presented here, the ‘training signal’ is included in the graph g. The same goes for train_g: one might consider that the network is only learning to predict high scores for nodes that are already connected in the graph it receives as input. Is this correct?
- When computing metrics for recommender systems, we try to replicate as closely as possible what would happen in a real setting. However, in this case, we report accuracy using the full graph g. This full graph already includes the edges in the test set, i.e. the ‘training signal’. This would mean that, in a user-item recommender system, we would already know which items a user will click on, and we could use that information to correctly predict those items. How can I ensure that the model does not have access to that information?
One could propose that inference be done on train_g directly, where there is no information on the test edges. However, this would create different settings for training (where the edges being predicted are all present in the graph) and for testing (where only ‘past’ edges are present in the graph), which might lead to poor generalization.
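One possible way to keep the two settings consistent, sketched here as an assumption rather than something the tutorial prescribes, is to hide the edges being predicted during training as well: the training edges are split into message-passing edges (kept in the input graph) and supervision edges (scored by the loss but removed from the input graph). The helper name `split_message_supervision` and the 25% ratio are hypothetical.

```python
import random

def split_message_supervision(train_eids, supervision_frac=0.25, seed=0):
    """Split training edge ids into message-passing and supervision ids.

    Hypothetical helper: the model would aggregate over the message-passing
    edges only and be trained to score the held-out supervision edges.
    """
    eids = list(train_eids)
    random.Random(seed).shuffle(eids)
    n_sup = max(1, int(len(eids) * supervision_frac))
    return eids[n_sup:], eids[:n_sup]

train_eids = list(range(16))  # stand-in for the train_eids computed earlier
msg_eids, sup_eids = split_message_supervision(train_eids)

# No supervision edge is visible to message passing, mirroring test time,
# where the test edges are absent from train_g.
assert not set(msg_eids) & set(sup_eids)
print(len(msg_eids), "message-passing edges,", len(sup_eids), "supervision edges")
```

With this split, the input graph for a training step would be built from msg_eids only (for example with the same edge_subgraph call used for train_g), and the positive examples for the loss would come from sup_eids. Re-drawing the split every epoch would let each training edge serve in both roles over time.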
Thanks a lot in advance!