Best way to send (batched) graphs to GPU


If I want to put a graph on a GPU (say “cuda:0”), how can I do that? I’m looking for the equivalent of a .to(“cuda:0”) operation. I would like to do this because I have a large dataset of small graphs, which I load with a DataLoader and batch together as shown in the “Batched Graph Classification With DGL” tutorial. I do not want to put the entire dataset on the GPU (because it’s too big), but I do want to send the graph batches to the GPU during training and inference.

Any help would be appreciated.



You can send the batched graph at the beginning of the forward pass, such as h = bg.ndata[‘h’]; h =



Thanks @taiky for the answer. Currently, you need to move each feature one by one. I’m thinking about providing a helper function like:

g = ...  # move all the feature data to device

What do you think?



Thanks taiky and minjie!
Yes, I think such a helper function would be useful. I’m a dgl and pytorch noob so it’s not clear to me which parts of a graph g need to be on the GPU for the model to perform calculations on it. Is it just g.ndata and g.edata?



Only g.ndata and g.edata. Where the graph structure itself is stored is currently managed by DGL. It is created on the CPU, but is cached on the GPU if required. We also invalidate the cache if the graph is mutated.



Hi minjie!
If I send all of my g.edata and g.ndata to the device before running model.forward(), I notice that g.in_degrees().view(-1,1).float() is still on my CPU. I could always send it to GPU with .to(“cuda:0”), but that would require me to pass the device (“cuda:0”) as an argument to my forward() function. Is there any nice way to get around this?




How about .to(input_tensor.device)?
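A minimal sketch of that suggestion, assuming a toy module: the DegreeScaled class and its degrees argument are hypothetical stand-ins (degrees plays the role of g.in_degrees()), and only PyTorch is used so it runs without DGL. The point is that forward() reads the device from the feature tensor it receives, so no device argument has to be passed in.

```python
import torch
import torch.nn as nn


class DegreeScaled(nn.Module):
    """Hypothetical layer that divides projected features by node degree."""

    def __init__(self, in_feats, out_feats):
        super().__init__()
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, h, degrees):
        # degrees may still live on the CPU; follow the device of h instead
        # of taking the device as an extra argument.
        degrees = degrees.view(-1, 1).float().to(h.device)
        # clamp avoids dividing by zero for isolated nodes
        return self.linear(h) / degrees.clamp(min=1)


# Fall back to the CPU when no GPU is present so the sketch still runs.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = DegreeScaled(4, 2).to(device)

h = torch.randn(3, 4, device=device)   # node features already on the device
degrees = torch.tensor([2, 0, 1])      # created on the CPU, like g.in_degrees()
out = model(h, degrees)
```

The same pattern applies inside a real forward(): take any tensor that is already on the right device (for example bg.ndata[‘h’]) and use its .device attribute.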



In case it’s helpful, here’s what I did:

def send_graph_to_device(g, device):
    # Move every node feature tensor to the target device.
    labels = g.node_attr_schemes()
    for l in labels.keys():
        g.ndata[l] = g.ndata.pop(l).to(device, non_blocking=True)
    # Move every edge feature tensor to the target device.
    labels = g.edge_attr_schemes()
    for l in labels.keys():
        g.edata[l] = g.edata.pop(l).to(device, non_blocking=True)
    return g
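
For anyone wondering how a helper like this fits the original question (moving one batch at a time rather than the whole dataset), here is a rough sketch. To keep it runnable without DGL, it uses a hypothetical move_features function and SimpleNamespace objects with plain dicts standing in for the batched graphs a DataLoader would yield; with a real graph you would pass the batched graph itself. It assumes g.ndata and g.edata behave like dictionaries. Newer DGL releases also provide a to(device) method on the graph, which may make such a helper unnecessary if your version has it.

```python
from types import SimpleNamespace

import torch


def move_features(g, device):
    # Hypothetical variant of the helper: works on any object exposing
    # dict-like .ndata and .edata (an assumption about the graph object).
    for frame in (g.ndata, g.edata):
        for key in list(frame.keys()):
            frame[key] = frame[key].to(device, non_blocking=True)
    return g


# Fall back to the CPU when no GPU is present so the sketch still runs.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Stand-ins for the batched graphs; the dataset itself stays on the CPU.
batches = [
    SimpleNamespace(
        ndata={"h": torch.randn(5, 3)},
        edata={"w": torch.randn(7, 1)},
    )
    for _ in range(2)
]

for bg in batches:
    bg = move_features(bg, device)  # only the current batch is moved
    h = bg.ndata["h"]
    # the model call, e.g. model(bg, h), would go here
```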