Reddit Dataset does not work with DataParallel

#1

Hi!
I have just started using DGL and the main motivation was that it simplifies working with graph data a lot.

However, I have run into the following issue: whenever I try to test my model on the Reddit dataset, it does not work.

If I use torch.nn.DataParallel, it throws the following error:

    dgl._ffi.base.DGLError: Expect number of features to match number of nodes (len(u)). Got 58242 and 232965 instead.

From what I can see in the code, this is probably because the Reddit dataset is not fully compatible with the PyTorch DataLoader.
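The two numbers in the error can be reproduced with plain arithmetic. torch.nn.DataParallel scatters every tensor argument along dimension 0 across the visible GPUs (the same split that torch.chunk produces), while a non-tensor argument such as the DGL graph is handed to each replica whole. Each replica therefore sees a slice of the feature rows but the full 232965-node graph. The sketch below assumes 4 GPUs, which is the only device count for which the first chunk has 58242 rows. The helper `chunk_sizes` is hypothetical and only mirrors how torch.chunk sizes its pieces:

```python
import math


def chunk_sizes(n_rows, n_devices):
    """Rows per chunk when n_rows are split like tensor.chunk(n_devices, dim=0)."""
    step = math.ceil(n_rows / n_devices)  # every chunk but the last gets ceil(n / k) rows
    sizes = []
    start = 0
    while start < n_rows:
        sizes.append(min(step, n_rows - start))  # the last chunk takes whatever is left
        start += step
    return sizes


N_NODES = 232965  # number of nodes reported in the error message
sizes = chunk_sizes(N_NODES, 4)  # assumption: 4 visible GPUs
print(sizes)  # [58242, 58242, 58242, 58239]
```

With PyTorch installed, `[c.shape[0] for c in torch.zeros(232965, 1).chunk(4)]` should give the same row counts. The first replica's 58242 feature rows are then compared with the 232965 nodes of the unsplit graph.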

I also tried torch.nn.parallel.DistributedDataParallel (in single-process, multi-GPU mode), but it hangs as soon as I wrap the model in torch.nn.parallel.DistributedDataParallel.

Unfortunately, the graph is quite large, so it does not fit into the memory of any single machine I have available.

Is there anything I can do about that? Thanks.

#2

Hi,

It seems your data is not being split correctly. Are you using multiprocessing workers or a custom collate function?

#3

Thank you for your reply.
No, I am not using either. I just load the data and store it in variables. My code looks as follows:

    import torch
    from dgl.data import register_data_args, load_data


    def get_dataset(args, device):
        data = load_data(args)
        # float32 matches the default dtype of nn.Module parameters
        features = torch.tensor(data.features, dtype=torch.float32, device=device)
        labels = torch.tensor(data.labels, dtype=torch.long, device=device)
        # boolean masks; indexing with uint8 (ByteTensor) masks is deprecated
        train_mask = torch.tensor(data.train_mask, dtype=torch.bool, device=device)
        val_mask = torch.tensor(data.val_mask, dtype=torch.bool, device=device)
        test_mask = torch.tensor(data.test_mask, dtype=torch.bool, device=device)
        in_feats = features.shape[1]
        n_classes = data.num_labels
        n_edges = data.graph.number_of_edges()

        return (
            data,
            features,
            labels,
            train_mask,
            val_mask,
            test_mask,
            in_feats,
            n_classes,
            n_edges,
        )

After that, I pass the data to the model directly, as in the DGL tutorial.
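For reference, full-graph training in the style of the DGL tutorial runs one forward pass over every node and applies the loss only to the nodes selected by the training mask. The sketch below is a minimal, hypothetical illustration: the toy sizes and random tensors stand in for the values returned by `get_dataset`, and `nn.Linear` stands in for a real GNN, which would also take the graph and aggregate over neighbours (and so needs all feature rows at once, aligned with the graph's nodes).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy sizes standing in for the Reddit tensors from get_dataset.
n_nodes, in_feats, n_classes = 10, 8, 3
features = torch.randn(n_nodes, in_feats)
labels = torch.randint(0, n_classes, (n_nodes,))
train_mask = torch.zeros(n_nodes, dtype=torch.bool)
train_mask[:6] = True  # treat the first 6 nodes as training nodes

# Stand-in for a GNN; a real model would be called as model(graph, features).
model = nn.Linear(in_feats, n_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
logits = model(features)  # one forward pass over all n_nodes rows
loss = F.cross_entropy(logits[train_mask], labels[train_mask])  # loss on training nodes only
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(logits.shape, loss.item())
```

Because the whole feature matrix is a single input here, there is no batch dimension that DataParallel could split without also splitting the graph.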

#4

Hi,

The Reddit dataset contains only a single graph, right? How would you parallelize it across multiple GPUs?