Not Able to Iterate through dataloader

Hey,
I am using the following PyTorch code to load the Cora dataset:

import dgl
from dgl.data import CoraGraphDataset
import torch
import traceback

def collate(sample):
    """Merge a list of ``(graph, feats, labels)`` samples into one batch.

    Args:
        sample: The list of dataset items passed in by the DataLoader. Each
            item is unpacked as a ``(graph, feats, labels)`` triple.
            NOTE(review): feats and labels are assumed to be NumPy arrays,
            since they are passed through ``np.concatenate`` -- confirm
            against the dataset actually used.

    Returns:
        A ``(graph, feats, labels)`` tuple: the result of ``dgl.batch`` on
        the graphs, and the concatenated features and labels converted to
        torch tensors.
    """
    # Imported locally because the snippet's import block does not bring
    # NumPy into scope, yet ``np.concatenate`` is needed below.
    import numpy as np

    # Transpose the list of triples into three parallel lists.
    graphs, feats, labels =map(list, zip(*sample))
    # Debug output; the three lists are separate arguments (commas), not
    # chained attribute lookups.
    print(graphs, feats, labels)
    graph = dgl.batch(graphs)
    feats = torch.from_numpy(np.concatenate(feats))
    labels = torch.from_numpy(np.concatenate(labels))
    return graph, feats, labels

# Load the Cora dataset
dataset = CoraGraphDataset()

# Create PyTorch DataLoader with custom collate function
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate)

# Loop through the dataloader; each batch is unpacked as the
# (graph, feats, labels) tuple that collate returns
for i, (bg, feats, label) in enumerate(dataloader):
    print(bg, label)

But I get this error:

Using backend: pytorch
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
Traceback (most recent call last):
  File "lol.py", line 21, in <module>
    for i, (bg, feats, label) in enumerate(dataloader):
  File "/home/kartik.anand.19031/.conda/envs/MSKD/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
    data = self._next_data()
  File "/home/kartik.anand.19031/.conda/envs/MSKD/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 561, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/kartik.anand.19031/.conda/envs/MSKD/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "lol.py", line 7, in collate
    graphs, feats, labels =map(list, zip(*sample))
  File "/home/kartik.anand.19031/.conda/envs/MSKD/lib/python3.7/site-packages/dgl/heterograph.py", line 2152, in __getitem__
    raise DGLError('Invalid key "{}". Must be one of the edge types.'.format(orig_key))
dgl._ffi.base.DGLError: Invalid key "0". Must be one of the edge types.

Hi @k-styles ,

In DGL, CoraGraphDataset() contains a single graph, so you cannot iterate over it as a collection of samples. You can access the graph via dataset[0].

@czkkkkkk Thanks for your reply,
I tried with PPIDataset() as well, but I am getting the same error.

Basically, I want to load the Cora dataset in a way similar to the following:

def get_data_loader(args):
    # Build the data loaders for the dataset named by args.dataset.
    #
    # For 'ppi' this returns a pair:
    #   (train_dataloader, valid_dataloader, test_dataloader,
    #    fixed_train_dataloader), data_info
    # where data_info is a dict with the keys 'n_classes', 'num_feats' and 'g'.
    # The 'cora' branch is left unfinished in this snippet.
    if args.dataset == 'ppi':
        # One dataset object per split.
        train_dataset = LegacyPPIDataset(mode='train')
        valid_dataset = LegacyPPIDataset(mode='valid')
        test_dataset = LegacyPPIDataset(mode='test')

        # All loaders share the same collate function and batch size.
        train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=collate, num_workers=4, shuffle=True)
        # Same training split as above, but created without shuffle=True.
        fixed_train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=collate, num_workers=4)
        valid_dataloader = DataLoader(valid_dataset, batch_size=args.batch_size, collate_fn=collate, num_workers=2)
        test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, collate_fn=collate, num_workers=2)

        # Class and feature counts are read from the second dimension of the
        # training split's labels and features.
        n_classes = train_dataset.labels.shape[1]
        num_feats = train_dataset.features.shape[1]
        g = train_dataset.graph
        # Collect the dataset metadata returned alongside the loaders.
        data_info = {}
        data_info['n_classes'] = n_classes
        data_info['num_feats'] = num_feats
        data_info['g'] = g
        return (train_dataloader, valid_dataloader, test_dataloader, fixed_train_dataloader), data_info

    elif args.dataset == 'cora':
        # Load cora dataset (not implemented in this snippet)

But at the end, when I try to iterate over the dataloader object, I get the error I mentioned in the post above.

Could you please let me know how I should go about doing this?

I think you should use GraphDataLoader, as in https://github.com/dmlc/dgl/blob/master/examples/pytorch/gat/train_ppi.py.

This topic was automatically closed 30 days after the last reply. New replies are no longer allowed.