GraphSAGE: Inference Performance

Hello! I have been playing around with the Reddit dataset and trying to write inference code. I intend to run inference on the full test set to reproduce the results in the original GraphSAGE paper. Looking at the inference function, it appears to evaluate on all nodes of the entire Reddit graph: nid in the dataloader is set to all nodes in the graph (th.arange(g.number_of_nodes())).

def inference(self, g, x, batch_size, device,args):
        dataloader = dgl.dataloading.NodeDataLoader(

It later keeps only the test-set nodes for the accuracy calculation:

return compute_acc(pred[val_nid], labels[val_nid])

This makes test-set inference (~48 sec) take far longer than a single training epoch (~21 sec).

If we strictly want to run inference/evaluation on the test-set nodes only, shouldn't we set val_nid to:

th.nonzero(~(test_g.ndata['train_mask'] | test_g.ndata['val_mask']), as_tuple=True)[0]
# or
th.nonzero(test_g.ndata['test_mask'], as_tuple=True)[0]

and pass val_nid as nid to the dataloader instead of th.arange(g.number_of_nodes())? This brings the inference time down to ~12.6 sec, but it also reduces test accuracy from 94.5% to 86.5%. (The 94.5% figure is obtained when evaluation runs on all nodes of the graph rather than only the test set; the authors of the GraphSAGE paper report the same test accuracy.) Going one step further, i.e., setting test_g to:

test_g = g.subgraph(g.ndata['test_mask'])

instead of:

test_g = g

in inductive_split(), we can bring the inference time down further to ~5.5 sec, with an accuracy of 92.78%.
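A toy illustration of what the subgraph step changes, in plain Python rather than DGL (the adjacency dictionary and the helper induced_subgraph are made up for this sketch): taking the subgraph induced by the test nodes drops every edge to a train or validation node, so each test node can aggregate only from other test nodes.

```python
def induced_subgraph(in_nbrs, keep):
    """Keep only the nodes in `keep` and the edges whose endpoints are both kept."""
    keep = set(keep)
    return {v: [u for u in in_nbrs[v] if u in keep] for v in in_nbrs if v in keep}


# Hypothetical chain graph 0 - 1 - 2 - 3; nodes 2 and 3 play the role of the test set.
in_nbrs = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
test_nodes = [2, 3]

sub = induced_subgraph(in_nbrs, test_nodes)
print(sub)  # {2: [3], 3: [2]}: node 2 no longer sees its non-test neighbor 1
```

Having fewer nodes and edges is one plausible reason the run is faster, and the different neighborhoods are why the predictions need not match full-graph inference.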


  1. What is the correct way to do inference strictly on the test set? I am trying to reproduce the results of the paper, where inference on the entire test set for GraphSAGE with the mean aggregator took ~1 sec with a micro F1 score (which is nothing but test accuracy) of 94.5. I want to see how close the DGL implementation can come to the original TF implementation, so it is critical to perform inference strictly on test-set nodes.

  2. Also, can you explain the drop in accuracy when I set nid to the test-set nodes:

    nid = th.nonzero(~(test_g.ndata['train_mask'] | test_g.ndata['val_mask']), as_tuple=True)[0]
    # or
    nid = th.nonzero(test_g.ndata['test_mask'], as_tuple=True)[0]


I trained a model for 10 epochs in the inductive setting with the following settings:

argparser.add_argument('--dataset', type=str, default='reddit')
argparser.add_argument('--aggr', type=str, default='mean')
argparser.add_argument('--num-epochs', type=int, default=10)
argparser.add_argument('--num-hidden', type=int, default=128)
argparser.add_argument('--num-layers', type=int, default=2)
argparser.add_argument('--fan-out', type=str, default='10,25')
argparser.add_argument('--batch-size', type=int, default=512)
argparser.add_argument('--log-every', type=int, default=20)
argparser.add_argument('--eval-every', type=int, default=5)
argparser.add_argument('--lr', type=float, default=0.01)
argparser.add_argument('--dropout', type=float, default=0.0)
argparser.add_argument('--num-workers', type=int, default=4,

and saved the trained model, which was then loaded by the inference code that calls the evaluate function.

Thanks!

The reason I put th.arange() instead of val_nid is that the prediction for val_nid depends on the representations of nodes other than val_nid, which in turn depend on a bunch of other nodes. If you set nid to val_nid only, the representations of the nodes outside val_nid are never updated, hence the drop in accuracy.
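The effect can be reproduced on a toy graph. The sketch below is plain Python, not the DGL example code: it assumes layer-wise inference writes each layer's output into a buffer initialized to zero and fills it only for the nodes in nid, and it uses a made-up mean-with-self aggregation on scalar features with no weights. All names (layerwise_mean, in_nbrs, feats) are hypothetical.

```python
def layerwise_mean(in_nbrs, feats, nid, num_layers):
    """Toy layer-wise inference: each layer averages a node with its in-neighbors.

    Only nodes in `nid` are written at each layer; every other node keeps
    the 0.0 it was initialized with (assumed zero-initialized output buffer).
    """
    h = dict(feats)
    for _ in range(num_layers):
        out = {v: 0.0 for v in in_nbrs}  # fresh zero buffer for this layer
        for v in nid:
            vals = [h[v]] + [h[u] for u in in_nbrs[v]]
            out[v] = sum(vals) / len(vals)
        h = out
    return h


# Hypothetical chain graph 0 - 1 - 2; node 2 plays the role of the test node.
in_nbrs = {0: [1], 1: [0, 2], 2: [1]}
feats = {0: 3.0, 1: 6.0, 2: 9.0}

full = layerwise_mean(in_nbrs, feats, nid=[0, 1, 2], num_layers=2)
restricted = layerwise_mean(in_nbrs, feats, nid=[2], num_layers=2)

print(full[2])        # 6.75
print(restricted[2])  # 3.75
```

With nid restricted to node 2, the second layer reads a stale zero for node 1, so node 2 ends up with a different representation than in the full pass.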

To do inference strictly on the test set, I would figure out the nodes needed as input to the last layer, then the nodes needed as input to the second-to-last layer, and so on, and compute representations only for those.
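A minimal sketch of that bookkeeping, in plain Python rather than DGL. The function required_nodes_per_layer and the adjacency dictionary are hypothetical names for illustration; the idea is to start from the test nodes and walk backwards one layer at a time, adding in-neighbors.

```python
def required_nodes_per_layer(in_nbrs, seeds, num_layers):
    """Return need[0..num_layers], where need[k] is the set of nodes whose
    layer-k representation is required (need[0] = input features,
    need[num_layers] = the seed nodes themselves)."""
    need = [set(seeds)]
    for _ in range(num_layers):
        prev = set(need[0])           # a node needs its own previous-layer state
        for v in need[0]:
            prev.update(in_nbrs[v])   # plus the previous-layer state of its neighbors
        need.insert(0, prev)
    return need


# Hypothetical chain graph 0 - 1 - 2 - 3 - 4; node 4 is the only test node.
in_nbrs = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}

need = required_nodes_per_layer(in_nbrs, seeds=[4], num_layers=2)
print(need)  # [{2, 3, 4}, {3, 4}, {4}]
```

On this toy graph, nodes 0 and 1 are never touched, which is where the savings over a full-graph pass would come from.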

Thanks for the explanation @BarclayII. Sounds reasonable.

So if I understood you correctly, accuracy drops because the representation of some test-set nodes may depend on nodes outside the test set. Since we never compute the representations of those outside nodes, the representations of the test-set nodes that depend on them end up less accurate. Did I summarize it correctly?

I am also wondering why accuracy goes up to ~93% when I set test_g = g.subgraph(g.ndata['test_mask']) instead of using the entire graph and provide th.nonzero(test_g.ndata['test_mask'], as_tuple=True)[0] as nid in the dataloader.

Can you elaborate a little on figuring out the nodes needed as input at each layer? By "layer", are you referring to the K layers used for aggregation?

The ~93% accuracy with test_g = g.subgraph(g.ndata['test_mask']) somewhat surprises me, since it means we are not losing much accuracy even if we only evaluate on the graph formed by the nodes in the test set. I guess this is specific to the Reddit dataset and may or may not work on other datasets.

Yes, by layer I mean the GraphSAGE layers used for aggregation. Basically, to compute the output of some nodes at the L-th layer, you need the outputs of those nodes as well as their neighbors at the (L-1)-th layer, and so on.
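To make the recursion concrete, here is a small plain-Python check (not DGL code) that computing only the nodes each layer actually needs gives the same result for the test node as computing every node at every layer. The mean-with-self aggregation, the scalar features, and all function names are assumptions made for this sketch.

```python
def mean_layer(in_nbrs, h, targets):
    """One toy layer: average each target node with its in-neighbors."""
    out = {}
    for v in targets:
        vals = [h[v]] + [h[u] for u in in_nbrs[v]]
        out[v] = sum(vals) / len(vals)
    return out


def full_inference(in_nbrs, feats, num_layers):
    """Compute every node at every layer."""
    h = dict(feats)
    for _ in range(num_layers):
        h = mean_layer(in_nbrs, h, in_nbrs.keys())
    return h


def restricted_inference(in_nbrs, feats, seeds, num_layers):
    """Compute only the nodes that the seeds depend on, layer by layer."""
    need = [set(seeds)]               # need[-1]: outputs wanted at the last layer
    for _ in range(num_layers):
        prev = set(need[0])
        for v in need[0]:
            prev.update(in_nbrs[v])   # layer L needs layer L-1 of self and neighbors
        need.insert(0, prev)
    h = {v: feats[v] for v in need[0]}
    for k in range(1, num_layers + 1):
        h = mean_layer(in_nbrs, h, need[k])
    return h


# Hypothetical chain graph 0 - 1 - ... - 5; node 5 is the only test node.
in_nbrs = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3, 5], 5: [4]}
feats = {v: float(v) for v in in_nbrs}

full = full_inference(in_nbrs, feats, num_layers=2)
restricted = restricted_inference(in_nbrs, feats, seeds=[5], num_layers=2)
print(full[5] == restricted[5])  # True
```

The restricted pass reads input features for nodes 3, 4 and 5 only, yet the value for node 5 matches the full pass.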

That being said, for computing validation accuracy in practice I would probably just use neighbor sampling, as we do during training, instead of computing every single node's representation at every single layer. Computing validation accuracy is a frequent operation, so we don't want to spend too much computation on it, and it is OK to introduce some randomness for model selection.
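A rough sketch of what neighbor sampling with per-layer fan-outs does, written in plain Python so it does not depend on any particular DGL version. The function sample_layers is hypothetical, and the convention that fanouts[k] belongs to layer k (sampling runs from the last layer back to the first) is an assumption of this sketch.

```python
import random


def sample_layers(in_nbrs, seeds, fanouts, rng):
    """For each layer, pick at most `fanout` in-neighbors per node.

    Returns one dict per layer (first layer first) mapping each node
    computed at that layer to its sampled in-neighbors.
    """
    layers = []
    frontier = list(seeds)
    for fanout in reversed(fanouts):      # start at the output layer
        sampled = {}
        for v in frontier:
            nbrs = in_nbrs[v]
            sampled[v] = list(nbrs) if len(nbrs) <= fanout else rng.sample(nbrs, fanout)
        layers.insert(0, sampled)
        # The previous layer must cover these nodes and their sampled neighbors.
        frontier = sorted(set(frontier) | {u for ns in sampled.values() for u in ns})
    return layers


# Hypothetical star graph: node 0 has ten neighbors, each connected back to node 0.
in_nbrs = {0: list(range(1, 11))}
in_nbrs.update({i: [0] for i in range(1, 11)})

rng = random.Random(0)
layers = sample_layers(in_nbrs, seeds=[0], fanouts=[2, 3], rng=rng)
print(len(layers[1][0]), len(layers[0]))  # 3 sampled neighbors at the last layer, 4 nodes at the first
```

Each evaluation then touches at most a bounded number of neighbors per node, at the cost of some randomness in the resulting accuracy.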

@BarclayII: Thanks for the explanation. Much appreciated.