How to evaluate link prediction with DataLoaders

First, the original edges are pre-split into train/val/test sets, and a NeighborSampler together with negative_sampler.GlobalUniform is used in the train/val DataLoaders while training the model.

  1. Since these sampled edges change between runs, how can I be sure I am evaluating the saved best model on the same positive and negative edges it was trained on? Is setting the same seeds as in the training setup enough to recreate the same positive and negative train edges from the DataLoader?

  2. After training the link prediction model, I can pass the entire original graph through it at inference time to get the node representations.
    How do I then evaluate the model on the test-set edges? I have the positive test edges (and likewise the positive train/val edges), but since I used a DataLoader and sampler to generate the negative edges in the training setup, what is the way to generate negative edges for the test setup?
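Regarding question 1, seeding does make uniform draws reproducible, but only if everything that consumes the RNG (worker count, shuffle order, iteration order) is identical between runs; a more robust alternative is simply to save the sampled negative edges from the best epoch. A toy sketch of the seeding principle, assuming the sampler draws from torch's RNG stream:

```python
import torch

def sample_global_uniform(num_nodes, num_samples, generator):
    # stand-in for what a global-uniform negative sampler draws internally
    src = torch.randint(0, num_nodes, (num_samples,), generator=generator)
    dst = torch.randint(0, num_nodes, (num_samples,), generator=generator)
    return src, dst

# same seed, same draw order -> identical negative edges
run1 = sample_global_uniform(100, 10, torch.Generator().manual_seed(42))
run2 = sample_global_uniform(100, 10, torch.Generator().manual_seed(42))
```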

Thanks

For 2, is something like this good?

```python
def eval_auc(pos_g, neg_g, emb):
    model.eval()  # disable dropout etc. in the predictor
    with torch.no_grad():
        # model.pred is some DotProduct or MLPPredictor saved with best model
        pos_score = model.pred(pos_g.to(args.device), emb)
        neg_score = model.pred(neg_g.to(args.device), emb)
        print('AUC', compute_auc(pos_score, neg_score))
```

```python
# dgl.graph takes the edge data as its first positional argument
neg_test_edges = dgl.sampling.global_uniform_negative_sampling(
    pos_test_g, num_samples=pos_test_g.num_edges())
neg_test_g = dgl.graph(neg_test_edges, num_nodes=pos_test_g.num_nodes())
eval_auc(pos_test_g, neg_test_g, emb)
```
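compute_auc isn't shown above; a minimal NumPy sketch (one of several equivalent ways to compute binary AUC from positive/negative scores) could be:

```python
import numpy as np

def compute_auc(pos_score, neg_score):
    # AUC as the Mann-Whitney statistic: P(pos score > neg score),
    # counting ties as one half. Pairwise O(n_pos * n_neg) version,
    # fine for evaluation-sized score arrays.
    pos = np.asarray(pos_score, dtype=float).ravel()
    neg = np.asarray(neg_score, dtype=float).ravel()
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))
```

For large test sets, sklearn.metrics.roc_auc_score on the concatenated scores with 1/0 labels gives the same result more efficiently.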

  1. I think so.
  2. Your code looks OK to me, except that you could replace pos_test_g with a graph that also includes the positive training/validation edges, since you probably don't want them to be sampled as negative test edges.
  1. Thanks, I get it now: since this is transductive link prediction, the graph passed to the negative sampler should include all the train and val edges too, but I should still use pos_test_g only for num_samples, to maintain a 1:1 ratio of positive to negative test edges?

Correct.
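For intuition, dgl.sampling.global_uniform_negative_sampling excludes the edges of whatever graph it is given, so passing the merged graph of all splits is what prevents true edges from becoming negatives. A plain-Python sketch of that idea (the function name is hypothetical):

```python
import random

def negatives_excluding_all_positives(num_nodes, all_pos_edges,
                                      num_samples, seed=0):
    # Sample uniform (u, v) pairs, rejecting self-loops and any edge that
    # is a known positive in ANY split -- which is why the exclusion set
    # should contain train + val + test edges, not just test edges.
    rng = random.Random(seed)
    forbidden = set(all_pos_edges)
    out = []
    while len(out) < num_samples:
        u, v = rng.randrange(num_nodes), rng.randrange(num_nodes)
        if u != v and (u, v) not in forbidden:
            out.append((u, v))
    return out
```

num_samples would still be pos_test_g.num_edges(), keeping the 1:1 positive-to-negative test ratio discussed above.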
