Training GNN for Edge Classification with Neighborhood Sampling on GPU

Hello, I’m running into a problem training an edge classification model for heterogeneous graphs on the GPU.
I built a binary heterogeneous bipartite graph and transferred it to the GPU as follows:

import dgl

hetero_graph = dgl.heterograph({
    ('left', 'point_to', 'right'): (graph_code_left, graph_code_right),
    ('right', 'point_to_by', 'left'): (graph_code_right, graph_code_left)
}).to(device)
g = hetero_graph
g = g.to(device)
g.device
device(type='cuda', index=0)

I created an EdgeDataLoader (DGL's PyTorch-based data loader) that iterates over the training edge IDs in train_eid_dict in batches and puts the resulting list of MFGs (message flow graphs) onto the GPU:

train_dataloader = dgl.dataloading.EdgeDataLoader(
    g, train_eid_dict, sampler, device=device,
    exclude='reverse_types',
    reverse_etypes={'point_to': 'point_to_by', 'point_to_by': 'point_to'},
    batch_size=batch_size,
    shuffle=True,
    drop_last=False,
    num_workers=0)

I also confirmed that train_dataloader is set to use the GPU:

train_dataloader.device
device(type='cuda', index=0)

However, when I use this setup to train the model, the following error appears (the same code runs normally on the CPU):

C:\Users\The Wind\anaconda3\envs\pytorch\lib\site-packages\dgl\base.py:45: DGLWarning: DGLGraph.__len__ is deprecated.Please directly call DGLGraph.number_of_nodes.
  return warnings.warn(message, category=category, stacklevel=1)
C:\Users\The Wind\anaconda3\envs\pytorch\lib\site-packages\dgl\base.py:45: DGLWarning: DGLGraph.__len__ is deprecated.Please directly call DGLGraph.number_of_nodes.
  return warnings.warn(message, category=category, stacklevel=1)
Traceback (most recent call last):
  File "C:\Users\The Wind\anaconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 521, in __next__
    data = self._next_data()
  File "C:\Users\The Wind\anaconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 561, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "C:\Users\The Wind\anaconda3\envs\pytorch\lib\site-packages\torch\utils\data\_utils\fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "C:\Users\The Wind\anaconda3\envs\pytorch\lib\site-packages\dgl\dataloading\pytorch\__init__.py", line 130, in collate
    result = super().collate(items)
  File "C:\Users\The Wind\anaconda3\envs\pytorch\lib\site-packages\dgl\dataloading\dataloader.py", line 722, in collate
    return self._collate(items)
  File "C:\Users\The Wind\anaconda3\envs\pytorch\lib\site-packages\dgl\dataloading\dataloader.py", line 626, in _collate
    pair_graph = self.g.edge_subgraph(items)
  File "C:\Users\The Wind\anaconda3\envs\pytorch\lib\site-packages\dgl\utils\internal.py", line 904, in _fn
    return func(*args, **kwargs)
  File "C:\Users\The Wind\anaconda3\envs\pytorch\lib\site-packages\dgl\subgraph.py", line 297, in edge_subgraph
    sgi = graph._graph.edge_subgraph(induced_edges, preserve_nodes)
  File "C:\Users\The Wind\anaconda3\envs\pytorch\lib\site-packages\dgl\heterograph_index.py", line 824, in edge_subgraph
    return _CAPI_DGLHeteroEdgeSubgraph(self, eids, preserve_nodes)
  File "C:\Users\The Wind\anaconda3\envs\pytorch\lib\site-packages\dgl\_ffi\_ctypes\function.py", line 188, in __call__
    check_call(_LIB.DGLFuncCall(
  File "C:\Users\The Wind\anaconda3\envs\pytorch\lib\site-packages\dgl\_ffi\base.py", line 64, in check_call
    raise DGLError(py_str(_LIB.DGLGetLastError()))
dgl._ffi.base.DGLError: [18:36:52] C:\Users\Administrator\dgl-0.5\src\graph\heterograph.cc:51: Check failed: hg->Context().device_type != kDLGPU: Edge subgraph with relabeling does not support GPU.

What is the reason for this? How can I fix this?
Thanks!

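It seems that edge_subgraph with node relabeling, which EdgeDataLoader calls in its collate step (see the last frames of the traceback), does not support DGLGraphs on the GPU in your DGL version. Keep the graph on the CPU: remove g = g.to(device), and also the .to(device) call at the end of the dgl.heterograph(...) expression, because that call already moves hetero_graph to the GPU.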

Do you mean that if I want to train a GNN for edge classification with neighborhood sampling, I can only use the CPU?

No. You can keep the graph on the CPU for sampling and move each batch of sampled data from the dataloader to the GPU inside the for loop; what fails is passing a graph that is already on the GPU to EdgeDataLoader. Also note that EdgeDataLoader is now deprecated; see the doc here.

I see, thank you very much for your reply!