Hi!
I am encountering problems when trying to send my graph to device for prediction. I do the following:
```python
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)

load_model_path = 'xxxxxxx.pt'
model = torch.load(load_model_path)
model.to(device)
print(next(model.parameters()).device)

user_feats = g.nodes['user'].data['feats'].to(device)
post_feats = g.nodes['post'].data['feats'].to(device)
feats = {'user': user_feats, 'post': post_feats}
print(post_feats.device)
print(user_feats.device)

g.to(device)
print(g.device)

updated_feats = model.sage(g, feats)
```
However, the graph is still on the CPU. Here is the output:
```
Device: cuda:0
cuda:0
cuda:0
cuda:0
cpu
```
```
Traceback (most recent call last):
  File "inference.py", line 152, in <module>
    updated_feats = model.sage(g, feats)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/code/dgl_recommendator/models.py", line 35, in forward
    h = self.conv1(graph, inputs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/dgl/nn/pytorch/hetero.py", line 189, in forward
    **mod_kwargs.get(etype, {}))
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/dgl/nn/pytorch/conv/graphconv.py", line 408, in forward
    feat_src = feat_src * norm
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```
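One thing I suspect: unlike `nn.Module.to`, which moves the module in place, `torch.Tensor.to` returns a new object, and I believe `DGLGraph.to` behaves the same way, so calling `g.to(device)` without reassigning `g` would leave the original graph on the CPU. A minimal pure-Python sketch of that pitfall (the `FakeGraph` class below is something I made up for illustration, not the DGL API):

```python
# Hypothetical stand-in for an object whose .to() returns a copy,
# the way torch.Tensor.to does (and, I assume, dgl.DGLGraph.to).
class FakeGraph:
    def __init__(self, device="cpu"):
        self.device = device

    def to(self, device):
        # Returns a *new* object; the receiver is left untouched.
        return FakeGraph(device)

g = FakeGraph()
g.to("cuda:0")      # return value discarded, so g is unchanged
print(g.device)     # -> cpu

g = g.to("cuda:0")  # reassigning keeps the moved graph
print(g.device)     # -> cuda:0
```

If that is the cause, changing the line to `g = g.to(device)` should fix it.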
Any ideas?
It’s weird because I use this same approach during training and it works perfectly.
Thanks.