hey,
I am using the following pytorch code for loading cora dataset:
import dgl
from dgl.data import CoraGraphDataset
import torch
import traceback
def collate(sample):
graphs, feats, labels =map(list, zip(*sample))
print(graphs. feats. labels)
graph = dgl.batch(graphs)
feats = torch.from_numpy(np.concatenate(feats))
labels = torch.from_numpy(np.concatenate(labels))
return graph, feats, labels
# Load PPI dataset
dataset = CoraGraphDataset()
# Create PyTorch DataLoader with custom collate function
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate)
# Loop through the dataloader
for i, (bg, feats, label) in enumerate(dataloader):
print(bg, label)
But i get this error:
Using backend: pytorch
NumNodes: 2708
NumEdges: 10556
NumFeats: 1433
NumClasses: 7
NumTrainingSamples: 140
NumValidationSamples: 500
NumTestSamples: 1000
Done loading data from cached files.
Traceback (most recent call last):
File "lol.py", line 21, in <module>
for i, (bg, feats, label) in enumerate(dataloader):
File "/home/kartik.anand.19031/.conda/envs/MSKD/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
data = self._next_data()
File "/home/kartik.anand.19031/.conda/envs/MSKD/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 561, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "/home/kartik.anand.19031/.conda/envs/MSKD/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
return self.collate_fn(data)
File "lol.py", line 7, in collate
graphs, feats, labels =map(list, zip(*sample))
File "/home/kartik.anand.19031/.conda/envs/MSKD/lib/python3.7/site-packages/dgl/heterograph.py", line 2152, in __getitem__
raise DGLError('Invalid key "{}". Must be one of the edge types.'.format(orig_key))
dgl._ffi.base.DGLError: Invalid key "0". Must be one of the edge types.