Problem using graphbolt.GPUCachedFeature: RuntimeError: Keys should be on a CUDA device

I am trying to add GPUCachedFeature to DGL's GraphBolt tutorials, using the link prediction example code.
First, load the dataset; both the graph and the features are on CPU:

dataset = gb.BuiltinDataset("cora").load()
device="cpu"
graph = dataset.graph.to(device)
feature = dataset.feature.to(device)
train_set = dataset.tasks[1].train_set
test_set = dataset.tasks[1].test_set
task_name = dataset.tasks[1].metadata["name"]

Then define the datapipe. I put copy_to last so that sampling is done on CPU:

from functools import partial
def create_train_dataloader():
    datapipe = gb.ItemSampler(train_set, batch_size=256, shuffle=True)
    datapipe = datapipe.sample_uniform_negative(graph, 5)
    datapipe = datapipe.sample_neighbor(graph, [5, 5])
    datapipe = datapipe.transform(partial(gb.exclude_seed_edges, include_reverse_edges=True))
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    datapipe = datapipe.copy_to("cuda:0")
    return gb.DataLoader(datapipe)

Create the GPUCachedFeature:

feature._features[('node',None,'feat')] = gb.GPUCachedFeature(feature._features[('node',None,'feat')], 512)

Define the model and start training:

class SAGE(nn.Module):
    ...  # same as the tutorial, skipped

in_size = feature.size("node", None, "feat")[0]
model = SAGE(in_size, 128).to("cuda:0") # move model to cuda
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(3):
    ...  # training loop, same as the tutorial

When I run the training code, it reports:

RuntimeError: Keys should be on a CUDA device.
This exception is thrown by __iter__ of FeatureFetcher(datapipe=MultiprocessingWrapper, edge_feature_keys=None, feature_store=TorchBasedFeatureStore(
    {(<OnDiskFeatureDataDomain.NODE: 'node'>, None, 'feat'): <dgl.graphbolt.impl.gpu_cached_feature.GPUCachedFeature object at 0x7f6bfda4e4a0>}
), node_feature_keys=['feat'])

It seems the keys used to read the cache also need to be on the GPU, but where in the datapipe can I move the keys to the GPU, and which datapipe op should I use?

feature = dataset.feature.pin_memory_()
# graph = dataset.graph.pin_memory_()
# graph = dataset.graph.to("cuda:0")
def create_train_dataloader():
    datapipe = gb.ItemSampler(train_set, batch_size=256, shuffle=True)
    # For GPU sampling, copy_to can be placed here instead.
    # For that, the graph needs to be moved to the GPU or be pinned:
    # datapipe = datapipe.copy_to("cuda:0", extra_attrs=["seed_nodes"])
    datapipe = datapipe.sample_uniform_negative(graph, 5)
    datapipe = datapipe.sample_neighbor(graph, [5, 5])
    datapipe = datapipe.transform(partial(gb.exclude_seed_edges, include_reverse_edges=True))
    datapipe = datapipe.copy_to("cuda:0", extra_attrs=["input_nodes"])
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    return gb.DataLoader(datapipe)

Placing copy_to before fetch_feature ensures that the keys are on the GPU. Pinning the features ensures that they are GPU accessible.

Note: the DGL 2.2 release in May will remove the requirement of passing the extra_attrs argument to copy_to.
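
If you want to sanity-check the ordering, here is a minimal sketch of a debug transform (my own hypothetical helper, assuming the homogeneous cora setup above, where the feature keys are the minibatch's input_nodes tensor):

def _assert_keys_on_gpu(minibatch):
    # Hypothetical debug helper: fetch_feature uses minibatch.input_nodes as the
    # keys for the GPU cache, so they must already be on a CUDA device here.
    assert minibatch.input_nodes.is_cuda, "copy_to must run before fetch_feature"
    return minibatch

# Insert it between copy_to and fetch_feature in the dataloader above:
datapipe = datapipe.transform(_assert_keys_on_gpu)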

Thanks for replying. But after pinning the feature and moving copy_to before fetch_feature, it still reports the same error.

My full code is below:


import dgl
import dgl.graphbolt as gb 
import torch
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
import time

dataset = gb.BuiltinDataset("cora").load()

device="cpu"
graph = dataset.graph.to(device)
feature = dataset.feature.to(device).pin_memory_() # pin the feature
assert feature.is_pinned()
train_set = dataset.tasks[1].train_set
test_set = dataset.tasks[1].test_set
task_name = dataset.tasks[1].metadata["name"]

from functools import partial
def create_train_dataloader():
    datapipe = gb.ItemSampler(train_set, batch_size=256, shuffle=True)
    datapipe = datapipe.sample_uniform_negative(graph, 5)
    datapipe = datapipe.sample_neighbor(graph, [5, 5])
    datapipe = datapipe.transform(partial(gb.exclude_seed_edges, include_reverse_edges=True))
    datapipe = datapipe.copy_to("cuda:0") # move copy_to before fetch_feature
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    return gb.DataLoader(datapipe)

feature._features[('node',None,'feat')] = gb.GPUCachedFeature(feature._features[('node',None,'feat')], 512)


class SAGE(nn.Module):
    def __init__(self, in_size, hidden_size):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
        self.hidden_size = hidden_size
        self.predictor = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, blocks, x):
        hidden_x = x
        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            hidden_x = layer(block, hidden_x)
            is_last_layer = layer_idx == len(self.layers) - 1
            if not is_last_layer:
                hidden_x = F.relu(hidden_x)
        return hidden_x


in_size = feature.size("node", None, "feat")[0]
model = SAGE(in_size, 128).to("cuda:0")
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(3):
    model.train()
    total_loss = 0
    for step, data in tqdm(enumerate(create_train_dataloader())):
        # Get node pairs with labels for loss calculation.
        compacted_pairs, labels = data.node_pairs_with_labels
        node_feature = data.node_features["feat"]
        # Convert sampled subgraphs to DGL blocks.
        blocks = data.blocks

        # Get the embeddings of the input nodes.
        y = model(blocks, node_feature)
        logits = model.predictor(
            y[compacted_pairs[0]] * y[compacted_pairs[1]]
        ).squeeze()

        # Compute loss.
        loss = F.binary_cross_entropy_with_logits(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch:03d} | Loss {total_loss / (step + 1):.3f}")

Full error message:

RuntimeError                              Traceback (most recent call last)
Cell In[10], line 7
      5 model.train()
      6 total_loss = 0
----> 7 for step, data in tqdm(enumerate(create_train_dataloader())):
      8     # Get node pairs with labels for loss calculation.
      9     compacted_pairs, labels = data.node_pairs_with_labels
     10     node_feature = data.node_features["feat"]

File ~/miniconda3/envs/dgl-dev-gpu-117/lib/python3.10/site-packages/tqdm/notebook.py:250, in tqdm_notebook.__iter__(self)
    248 try:
    249     it = super(tqdm_notebook, self).__iter__()
--> 250     for obj in it:
    251         # return super(tqdm...) will not catch exception
    252         yield obj
    253 # NB: except ... [ as ...] breaks IPython async KeyboardInterrupt

File ~/miniconda3/envs/dgl-dev-gpu-117/lib/python3.10/site-packages/tqdm/std.py:1181, in tqdm.__iter__(self)
   1178 time = self._time
   1180 try:
-> 1181     for obj in iterable:
   1182         yield obj
   1183         # Update and possibly print the progressbar.
   1184         # Note: does not call self.update(1) for speed optimisation.
...

RuntimeError: Keys should be on a CUDA device.
This exception is thrown by __iter__ of FeatureFetcher(datapipe=MultiprocessingWrapper, edge_feature_keys=None, feature_store=TorchBasedFeatureStore(
    {(<OnDiskFeatureDataDomain.NODE: 'node'>, None, 'feat'): <dgl.graphbolt.impl.gpu_cached_feature.GPUCachedFeature object at 0x7f313032d780>}
), node_feature_keys=['feat'])

You forgot the extra_attrs argument of copy_to; paste the following into your code and it should work:
datapipe = datapipe.copy_to("cuda:0", extra_attrs=["input_nodes"])

This requirement will be dropped in DGL’s 2.2 release.
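
For completeness, here is the create_train_dataloader from the post above with only that line changed (a sketch based on the code already in this thread; the feature pinning and GPUCachedFeature wrapping stay as they are):

from functools import partial

def create_train_dataloader():
    datapipe = gb.ItemSampler(train_set, batch_size=256, shuffle=True)
    datapipe = datapipe.sample_uniform_negative(graph, 5)
    datapipe = datapipe.sample_neighbor(graph, [5, 5])
    datapipe = datapipe.transform(partial(gb.exclude_seed_edges, include_reverse_edges=True))
    # extra_attrs also moves input_nodes (the feature-lookup keys) to the GPU,
    # so fetch_feature can query the GPU cache with CUDA keys.
    datapipe = datapipe.copy_to("cuda:0", extra_attrs=["input_nodes"])
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    return gb.DataLoader(datapipe)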