DataLoader issue for a link prediction task with PyTorch Lightning and DGL

I put this together from a mix of the DGL examples and a recent GitHub commit of the GraphSAGE Lightning example, adapted for GAT link prediction.
However, overfitting on a single batch does not work, and I can't figure out from the error where exactly the issue comes from.
The usual Lightning fit loop with training and validation ran fine, but the loss barely changed there, so I wanted to try overfitting, and then this error happened.

import numpy as np
from sklearn.metrics import roc_auc_score
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl.data import CoraGraphDataset
import glob
import os
import itertools
from torchmetrics import Accuracy
from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary
from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything
from gat import GAT

class DotPredictor(nn.Module):
    """Edge scorer: dot product of the two endpoint representations."""

    def forward(self, edge_subgraph, h):
        """Return one score per edge of ``edge_subgraph``.

        ``h`` is attached to the graph as node data ``'h'``; each edge is then
        scored with the dot product of ``'h'`` at its source and destination.
        """
        # Work inside a local scope so the temporary 'h' / 'score' fields are
        # only written within this block.
        with edge_subgraph.local_scope():
            edge_subgraph.ndata['h'] = h
            dot_message = fn.u_dot_v('h', 'h', 'score')
            edge_subgraph.apply_edges(dot_message)
            edge_scores = edge_subgraph.edata['score']
            # Keep column 0 only, giving a flat score per edge.
            return edge_scores[:, 0]

def compute_loss(pos_score, neg_score):
    """Binary cross-entropy link-prediction loss.

    Args:
        pos_score: 1-D tensor of logits for positive (existing) edges; target 1.
        neg_score: 1-D tensor of logits for negative (sampled) edges; target 0.

    Returns:
        Scalar tensor with the mean BCE-with-logits over all scores.
    """
    scores = torch.cat([pos_score, neg_score])
    # Derive the targets from the score tensors so they share their device and
    # dtype, instead of building them on CPU and moving them to a device read
    # from a module-level ``args`` object.
    labels = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
    return F.binary_cross_entropy_with_logits(scores, labels)

def compute_auc(pos_score, neg_score):
    """ROC-AUC of positive-vs-negative edge scores.

    Args:
        pos_score: 1-D tensor of scores for positive edges (label 1).
        neg_score: 1-D tensor of scores for negative edges (label 0).

    Returns:
        The value returned by ``roc_auc_score`` for these labels and scores.
    """
    # detach() first: converting a tensor that still requires grad to NumPy
    # raises, so callers no longer have to remember to detach themselves.
    scores = torch.cat([pos_score, neg_score]).detach().cpu().numpy()
    labels = torch.cat(
        [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]).numpy()
    return roc_auc_score(labels, scores)

class GATLightning(LightningModule):
    """LightningModule that trains a ``GAT`` encoder for link prediction.

    Each batch is a 4-tuple ``(input_nodes, pos_graph, neg_graph, mfgs)``. The
    encoder is run on ``mfgs`` and the resulting node representations are
    scored on the positive and negative graphs with a ``DotPredictor``.
    Metrics logged under the ``*_acc`` keys are ROC-AUC values from
    ``compute_auc`` (the key names are kept as they were).
    """

    def __init__(self,
                 num_layers,
                 in_dim,
                 num_hidden,
                 num_classes,
                 heads,
                 activation,
                 feat_drop,
                 attn_drop,
                 negative_slope,
                 residual,
                 lr,
                 weight_decay):
        """Build the encoder and predictor.

        All arguments except ``lr`` and ``weight_decay`` are forwarded to
        ``GAT`` in the order given; ``lr`` and ``weight_decay`` configure the
        Adam optimizer. Every argument is recorded via
        ``save_hyperparameters``.
        """
        super().__init__()
        self.save_hyperparameters()

        self.module = GAT(num_layers, in_dim, num_hidden, num_classes,
                          heads, activation, feat_drop, attn_drop,
                          negative_slope, residual)
        self.pred = DotPredictor()

        # Not used by the step methods below; kept so existing attribute
        # access on the module keeps working.
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

    def _score_batch(self, batch):
        """Encode one batch and score its positive and negative edges.

        Returns:
            ``(pos_score, neg_score, batch_size)`` where ``batch_size`` is the
            first dimension of the encoder output.
        """
        _input_nodes, pos_graph, neg_graph, mfgs = batch
        mfgs = [mfg.int().to(self.device) for mfg in mfgs]
        batch_inputs = mfgs[0].srcdata['feat']

        batch_pred = self.module(mfgs, batch_inputs)
        pos_score = self.pred(pos_graph, batch_pred)
        neg_score = self.pred(neg_graph, batch_pred)
        return pos_score, neg_score, batch_pred.size(0)

    def _log_auc(self, name, pos_score, neg_score, batch_size):
        """Compute ROC-AUC for the scores and log it per epoch under ``name``."""
        auc = compute_auc(pos_score.detach(), neg_score.detach())
        # Pass batch_size explicitly (as is done for train_loss) so Lightning
        # does not have to infer it from the DGL batch tuple.
        self.log(name, auc, prog_bar=True, on_step=False, on_epoch=True,
                 batch_size=batch_size)

    def training_step(self, batch, batch_idx):
        """Return the BCE loss for ``batch`` and log ``train_loss`` / ``train_acc``."""
        pos_score, neg_score, batch_size = self._score_batch(batch)

        loss = compute_loss(pos_score, neg_score)
        self.log('train_loss', loss,
                 prog_bar=True, on_step=False, on_epoch=True,
                 batch_size=batch_size)
        self._log_auc('train_acc', pos_score, neg_score, batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``val_acc`` (ROC-AUC) for ``batch``."""
        pos_score, neg_score, batch_size = self._score_batch(batch)
        self._log_auc('val_acc', pos_score, neg_score, batch_size)

    def test_step(self, batch, batch_idx):
        """Log ``test_acc`` (ROC-AUC) for ``batch``."""
        pos_score, neg_score, batch_size = self._score_batch(batch)
        self._log_auc('test_acc', pos_score, neg_score, batch_size)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of this module."""
        # self.parameters() already walks every submodule, self.pred included;
        # chaining self.pred.parameters() on top would hand any predictor
        # parameters to the optimizer twice.
        return torch.optim.Adam(self.parameters(),
                                lr=self.hparams.lr,
                                weight_decay=self.hparams.weight_decay)

class DataModule(LightningDataModule):
    """Cora edge-prediction data: edge split, samplers and DGL data loaders."""

    def __init__(self, dataset_name, raw_dir, data_cpu=False,
                 device=torch.device('cpu'), batch_size=1000, num_workers=0,
                 num_layers=None):
        """Store the configuration; no data is loaded here.

        Args:
            dataset_name: Only ``'cora'`` is accepted by ``prepare_data``.
            raw_dir: Directory handed to ``CoraGraphDataset``.
            data_cpu: Stored but not used by this class.
            device: Stored but not passed to the data loaders.
            batch_size: Number of seed edges per batch.
            num_workers: Worker count for the data loaders.
            num_layers: Number of layers for the full-neighbor sampler. When
                ``None`` (default), ``setup`` falls back to the module-level
                ``args.num_layers`` as before.
        """
        super().__init__()
        self.dataset_name = dataset_name
        self.raw_dir = raw_dir
        self.data_cpu = data_cpu
        self.device = device
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.num_layers = num_layers

    def prepare_data(self):
        """Fetch the dataset into ``raw_dir`` if it is not there yet.

        Raises:
            ValueError: If ``dataset_name`` is not ``'cora'``.
        """
        if self.dataset_name != 'cora':
            raise ValueError('unknown dataset')
        if not os.path.exists(os.path.join(self.raw_dir, 'cora_v2')):
            CoraGraphDataset(raw_dir=self.raw_dir)

    @staticmethod
    def _with_self_loops(graph):
        """Return ``graph`` with existing self-loops removed and fresh ones added."""
        return graph.remove_self_loop().add_self_loop()

    def setup(self, stage=None):
        """Load Cora, split its edges and build the edge-prediction sampler."""
        data = CoraGraphDataset(raw_dir=self.raw_dir)
        g = data[0]
        self.g = g
        self.in_feats = g.ndata['feat'].shape[1]
        self.n_classes = data.num_classes

        num_layers = self.num_layers
        if num_layers is None:
            # Backward-compatible fallback to the script-level settings object.
            num_layers = args.num_layers

        # Random edge split over a shuffled list of edge ids:
        #   eids[:test_end]            -> kept only in the test graph
        #   eids[test_end:holdout_end] -> kept only in the validation graph
        #   eids[holdout_end:]         -> kept only in the training graph
        eids = np.random.permutation(np.arange(g.number_of_edges()))
        holdout_end = int(len(eids) * 0.2)
        test_end = int(len(eids) * 0.1)

        # Alternatives tried by the original author: NeighborSampler with a
        # fan-out of 4 per layer, and negative_sampler.Uniform(5).
        neighbor_sampler = dgl.dataloading.MultiLayerFullNeighborSampler(num_layers)
        self.negative_sampler = dgl.dataloading.negative_sampler.GlobalUniform(5)
        self.sampler = dgl.dataloading.as_edge_prediction_sampler(
            sampler=neighbor_sampler, negative_sampler=self.negative_sampler)

        # NOTE(review): the self-loops added here are part of each graph's edge
        # set, and the loaders iterate over *all* edge ids, so self-loops are
        # also served as positive edges. Confirm this is intended.
        self.train_g = self._with_self_loops(
            dgl.remove_edges(g, eids[:holdout_end]))
        self.val_g = self._with_self_loops(
            dgl.remove_edges(g, np.concatenate([eids[:test_end], eids[holdout_end:]])))
        self.test_g = self._with_self_loops(
            dgl.remove_edges(g, eids[test_end:]))

    def _edge_loader(self, graph, shuffle):
        """Build a DGL ``DataLoader`` over every edge id of ``graph``."""
        # ``device`` is deliberately not passed; the original author's note
        # says Lightning takes care of it.
        return dgl.dataloading.DataLoader(
            graph,
            torch.arange(graph.number_of_edges()),
            self.sampler,
            batch_size=self.batch_size,
            shuffle=shuffle,
            drop_last=False,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Shuffled loader over the training graph's edges."""
        return self._edge_loader(self.train_g, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the validation graph's edges."""
        return self._edge_loader(self.val_g, shuffle=False)

    def test_dataloader(self):
        """Unshuffled loader over the test graph's edges."""
        return self._edge_loader(self.test_g, shuffle=False)

class argparserclass():
    """Hard-coded settings object used in place of parsed command-line arguments."""

    def __init__(self):
        # Runtime and data location
        self.device = torch.device('cuda:1')
        self.dataset = 'cora'
        self.raw_dir = "DATA/.dgl/"
        self.data_cpu = False

        # Model shape
        self.in_dim = 1433
        self.num_classes = 7
        self.num_layers = 2
        self.num_hidden = 16
        self.num_heads = 8
        self.num_out_heads = 1
        self.residual = False
        self.negative_slope = 0.2

        # Regularisation
        self.in_drop = 0.6
        self.attn_drop = 0.6
        self.weight_decay = 5e-4

        # Optimisation and loading
        self.lr = 0.005
        self.epochs = 200
        self.batch_size = 500
        self.num_workers = 0
        self.shuffle = False

# Script entry: build the settings, data module and model, then fit.
args = argparserclass()

datamodule = DataModule(
    dataset_name=args.dataset,
    raw_dir=args.raw_dir,
    data_cpu=args.data_cpu,
    device=args.device,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
)

# num_heads for every layer except the last, which uses num_out_heads.
heads = [args.num_heads] * (args.num_layers - 1) + [args.num_out_heads]

model = GATLightning(
    num_layers=args.num_layers,
    in_dim=args.in_dim,
    num_hidden=args.num_hidden,
    num_classes=args.num_classes,
    heads=heads,
    activation=F.elu,
    feat_drop=args.in_drop,
    attn_drop=args.attn_drop,
    negative_slope=args.negative_slope,
    residual=args.residual,
    lr=args.lr,
    weight_decay=args.weight_decay,
)

# trainer = Trainer(gpus = [1], fast_dev_run=True)
trainer = Trainer(gpus=[1], overfit_batches=2)  # not working!?
trainer.fit(model, datamodule)

TypeError: torch.utils.data.dataloader.DataLoader.__init__() got multiple values for keyword argument 'worker_init_fn'

Hi,

What are your DGL and PyTorch Lightning versions?

dgl-cuda11.3 0.8.0post1
pytorch-lightning 1.5.10

So, I tried the same code without the PyTorch Lightning dependency, and when overfitting on a single batch the training loss stayed almost the same for 200+ epochs. Could anyone help me find out whether there is any misuse of DGL functions in the link prediction code above?

Thanks

Edit: the loss started improving after I changed the sampler to draw 2 negative edges per positive edge instead of 5.

@BarclayII found an issue related to this and patched it.
I will try again later with the fix and post an update about the issue once the commits are finalized.

This topic was automatically closed 30 days after the last reply. New replies are no longer allowed.