How to correctly train the GraphSAGE model on the Papers100M dataset?

I have seen in several papers that GraphSAGE can reach an accuracy of around 0.65 on the ogbn-papers100M dataset. However, when I follow their training setup, I only reach around 0.56. I would like to know where the issue might be and how to set the hyperparameters correctly to reproduce the expected results. My training script is below:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import dgl
import dgl.nn as dglnn
from dgl.data import AsNodePredDataset
from dgl.dataloading import DataLoader, NeighborSampler, MultiLayerFullNeighborSampler
from ogb.nodeproppred import DglNodePropPredDataset
import tqdm
import argparse
import ast
import sklearn.metrics
import numpy as np
import time

class SAGE(nn.Module):
    def __init__(self, in_size, hid_size, out_size, num_layers=2, dropout=0.5):
        super().__init__()
        self.layers = nn.ModuleList()
        # multi-layer GraphSAGE-mean: input -> (num_layers - 2) hidden -> output
        self.layers.append(dglnn.SAGEConv(in_size, hid_size, 'mean'))
        for _ in range(num_layers - 2):
            self.layers.append(dglnn.SAGEConv(hid_size, hid_size, 'mean'))
        self.layers.append(dglnn.SAGEConv(hid_size, out_size, 'mean'))
        self.dropout = nn.Dropout(dropout)
        self.hid_size = hid_size
        self.out_size = out_size

    def forward(self, blocks, x):
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h

    def inference(self, g, device, batch_size):
        """Conduct layer-wise inference to get all the node embeddings."""
        feat = g.ndata['feat']
        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
        dataloader = DataLoader(
                g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
                batch_size=batch_size, shuffle=False, drop_last=False,
                num_workers=0)
        buffer_device = torch.device('cpu')
        pin_memory = (buffer_device != device)

        for l, layer in enumerate(self.layers):
            y = torch.empty(
                g.num_nodes(), self.hid_size if l != len(self.layers) - 1 else self.out_size,
                device=buffer_device, pin_memory=pin_memory)
            feat = feat.to(device)
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader, position=0):
                x = feat[input_nodes]
                h = layer(blocks[0], x) # len(blocks) = 1
                if l != len(self.layers) - 1:
                    h = F.relu(h)
                    h = self.dropout(h)
                # by design, our output nodes are contiguous
                y[output_nodes[0]:output_nodes[-1]+1] = h.to(buffer_device)
            feat = y
        return y


def evaluate(model, graph, dataloader):
    model.eval()
    ys = []
    y_hats = []
    for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
        with torch.no_grad():
            x = blocks[0].srcdata['feat']
            ys.append(blocks[-1].dstdata['label'].cpu().numpy())
            y_hats.append(model(blocks, x).argmax(1).cpu().numpy())
    # concatenate once after the loop instead of on every iteration
    predictions = np.concatenate(y_hats)
    labels = np.concatenate(ys)
    return sklearn.metrics.accuracy_score(labels, predictions)

def layerwise_infer(device, graph, nid, model, batch_size):
    model.eval()
    with torch.no_grad():
        pred = model.inference(graph, device, batch_size) # pred in buffer_device
        pred = pred[nid]
        label = graph.ndata['label'][nid].to(pred.device)
    return sklearn.metrics.accuracy_score(label.cpu().numpy(), pred.argmax(1).cpu().numpy())

def train(args, device, g, dataset, model):
    train_idx = dataset.train_idx.to(device)
    val_idx = dataset.val_idx.to(device)
    sampler = NeighborSampler(args.fanout)
    use_uva = (args.mode == 'mixed')  # UVA: GPU samples from the graph pinned in host memory
    train_dataloader = DataLoader(g, train_idx, sampler, device=device,
                                  batch_size=1024, shuffle=True,  # shuffle training batches each epoch
                                  drop_last=False, num_workers=0,
                                  use_uva=use_uva)

    val_dataloader = DataLoader(g, val_idx, sampler, device=device,
                                batch_size=1024, shuffle=True,
                                drop_last=False, num_workers=0,
                                use_uva=use_uva)

    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
    for epoch in range(args.maxloop):
        model.train()
        total_loss = 0
        start_time = time.time()
        count = 0
        for it, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader):
            x = blocks[0].srcdata['feat']
            y = blocks[-1].dstdata['label']
            y_hat = model(blocks, x)
            y = y.to(torch.int64)  # papers100M labels are float (NaN for unlabeled); cast for cross_entropy
            loss = F.cross_entropy(y_hat, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item()
            count = it
        train_time = time.time() - start_time
        acc = evaluate(model, g, val_dataloader)
        print("| Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} | Time {:.3f}s | Batches {} |"
              .format(epoch, total_loss / (it + 1), acc, train_time, count + 1))
    
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", default='mixed', choices=['cpu', 'mixed', 'puregpu'],
                        help="Training mode. 'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, "
                             "'puregpu' for pure-GPU training.")
    parser.add_argument('--fanout', type=ast.literal_eval, default=[10, 10, 10], help='Fanout value')
    parser.add_argument('--layers', type=int, default=3, help='Number of layers')
    parser.add_argument('--dataset', type=str, default='ogbn-papers100M', help='Dataset name')
    parser.add_argument('--maxloop', type=int, default=20, help='max loop number')
    parser.add_argument('--model', type=str, default="SAGE", help='train model')
    args = parser.parse_args()
    if not torch.cuda.is_available():
        args.mode = 'cpu'
    print(f'Training in {args.mode} mode.')

    dataset = AsNodePredDataset(DglNodePropPredDataset(args.dataset))
    g = dataset[0]
    device = torch.device('cpu' if args.mode == 'cpu' else 'cuda')

    # papers100M is a directed citation graph; symmetrize it on CPU so that mean
    # aggregation sees both citing and cited neighbors, then move it if needed
    sym_g = dgl.to_bidirected(g)
    for key in g.ndata:
        sym_g.ndata[key] = g.ndata[key]
    g = sym_g
    g = g.to('cuda' if args.mode == 'puregpu' else 'cpu')
    in_size = g.ndata['feat'].shape[1]
    out_size = dataset.num_classes
    model = SAGE(in_size, 256, out_size, num_layers=args.layers, dropout=0.0).to(device)

    # model training
    print('Training...')
    train(args, device, g, dataset, model)
    model.eval()

    sampler_test = NeighborSampler([20,50,50],
                                    prefetch_node_feats=['feat'],
                                    prefetch_labels=['label'])
    test_dataloader = DataLoader(g, dataset.test_idx, sampler_test, device=device,
                                    batch_size=4096, shuffle=True,
                                    drop_last=False, num_workers=0,
                                    use_uva=(args.mode == 'mixed'))
    acc = evaluate(model, g, test_dataloader)
    print("Test Accuracy {:.4f}".format(acc.item()))
    print("-"*20)

Are you calling layerwise_infer() the same way the DGL GraphSAGE examples do? https://github.com/dmlc/dgl/tree/master/examples/pytorch/graphsage
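
For reference, the tail of that example runs a layer-wise, full-neighbor inference pass over the test split instead of a sampled test dataloader. A minimal sketch using the layerwise_infer defined above (exact arguments may differ between example versions):

# full-neighbor, layer-by-layer inference: the resulting test accuracy
# is not affected by neighbor-sampling variance
acc = layerwise_infer(device, g, dataset.test_idx, model, batch_size=4096)
print("Test Accuracy {:.4f}".format(acc))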

Hi, after using the linked code and converting papers100M to an undirected graph, I did see an improvement in accuracy; however, I am still only reaching around 0.62. So I am curious what accuracy is actually achievable, or what changes to my parameters are needed to reach 0.65. My current parameters are: batch size = 1024, fanout = [10, 10, 10], hidden size = 256, dropout = 0.0. The undirected conversion I used is sketched below.
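
The conversion is the same symmetrization as in my script above; a minimal sketch (assuming g is the CPU-resident DGL graph with node data attached):

sym_g = dgl.to_bidirected(g)         # add reverse edges so every citation goes both ways
for key in g.ndata:
    sym_g.ndata[key] = g.ndata[key]  # carry over 'feat', 'label', and the split masks
g = sym_g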

I couldn’t see the number 0.65 for GraphSAGE on the leaderboard. Do you know where the 0.65 figure first showed up? If so, does the paper come with code to reproduce the GraphSAGE result?

Thank you for your response. I found papers100M training results of around 0.66 accuracy in Table 3 of the Marius paper and Table 5 of the DUCATI paper; both projects have open-source repositories.

When I submitted an issue to the author of the latter, he explicitly mentioned that reproducing the reported accuracy only requires the standard GraphSAGE model. Link to the issue.
