I want to get the gradient of the loss with respect to the input features.
Here is my code:
import dgl
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F


class SAGE(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super().__init__()
        self.conv1 = dglnn.SAGEConv(
            in_feats=in_feats, out_feats=hid_feats, aggregator_type='mean')
        self.conv2 = dglnn.SAGEConv(
            in_feats=hid_feats, out_feats=out_feats, aggregator_type='mean')

    def forward(self, graph, inputs):
        # inputs are features of nodes
        h = self.conv1(graph, inputs)
        h = F.relu(h)
        h = self.conv2(graph, h)
        return h


def backward_hook(module, grad_input, grad_output):
    # capture the gradient flowing into conv1's self-projection layer
    global input_grad
    input_grad = grad_input[1].data


g = dgl.graph(([0, 2, 0, 1], [1, 0, 3, 2]), idtype=torch.int32)
g.ndata['feat'] = torch.arange(32).reshape((4, 8)).float()
g.ndata['feat'].requires_grad_()
g = dgl.add_reverse_edges(g, copy_ndata=True)
g.ndata['label'] = torch.ones((g.num_nodes(), 1)).long()

model = SAGE(8, 4, 2)
opt = torch.optim.Adam(model.parameters())
criterion = torch.nn.CrossEntropyLoss()
model.conv1.fc_self.register_backward_hook(backward_hook)
model.train()

sampler = dgl.dataloading.NeighborSampler(
    [1, 2],
    replace=True,
    prefetch_node_feats=['feat'],
    prefetch_labels=['label']
)
train_nids = torch.arange(g.num_nodes()).int()
dataloader = dgl.dataloading.DataLoader(
    g,
    train_nids,
    sampler,
    batch_size=2,
    shuffle=False,
    drop_last=False,
    num_workers=0
)

for input_nodes, output_nodes, blocks in dataloader:
    input_features = blocks[0].srcdata['feat']
    node_labels = blocks[-1].dstdata['label']
    node_labels = torch.squeeze(node_labels)
    node_predictions = model(blocks, input_features)
    loss = criterion(node_predictions, node_labels.long())
    opt.zero_grad()
    loss.backward()
    opt.step()
Running it raises this error:
Traceback (most recent call last):
  File "dataloader.py", line 334, in _prefetch
    feats = recursive_apply(batch, _prefetch_for, dataloader)
  File "lib\site-packages\dgl\utils\internal.py", line 1132, in recursive_apply
    return [recursive_apply(v, fn, *args, **kwargs) for v in data]
  File "lib\site-packages\dgl\utils\internal.py", line 1132, in <listcomp>
    return [recursive_apply(v, fn, *args, **kwargs) for v in data]
  File "lib\site-packages\dgl\utils\internal.py", line 1132, in recursive_apply
    return [recursive_apply(v, fn, *args, **kwargs) for v in data]
  File "lib\site-packages\dgl\utils\internal.py", line 1132, in <listcomp>
    return [recursive_apply(v, fn, *args, **kwargs) for v in data]
  File "lib\site-packages\dgl\utils\internal.py", line 1134, in recursive_apply
    return fn(data, *args, **kwargs)
  File "lib\site-packages\dgl\dataloading\dataloader.py", line 286, in _prefetch_for
    return _prefetch_for_subgraph(item, dataloader)
  File "lib\site-packages\dgl\dataloading\dataloader.py", line 275, in _prefetch_for_subgraph
    _prefetch_update_feats(
  File "lib\site-packages\dgl\dataloading\dataloader.py", line 260, in _prefetch_update_feats
    feats[tid, key] = get_storage_func(parent_key, type).fetch(
  File "lib\site-packages\dgl\frame.py", line 530, in fetch
    return super().fetch(indices, device, pin_memory=pin_memory, **kwargs)
  File "lib\site-packages\dgl\storages\pytorch_tensor.py", line 48, in fetch
    return _fetch_cpu(
  File "lib\site-packages\dgl\storages\pytorch_tensor.py", line 17, in _fetch_cpu
    torch.index_select(tensor, 0, indices, out=result)
RuntimeError: index_select(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.
It seems the DataLoader's feature prefetching doesn't support tensors that require grad.
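The failing call is easy to reproduce in isolation (a minimal sketch based on the last frame of the traceback):

    import torch

    t = torch.randn(4, 8, requires_grad=True)      # plays the role of g.ndata['feat']
    result = torch.empty(2, 8)
    indices = torch.tensor([0, 1])
    # Same pattern as in dgl/storages/pytorch_tensor.py: the out= variant of
    # index_select rejects arguments that require grad.
    torch.index_select(t, 0, indices, out=result)   # raises the RuntimeError above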
If I don't use neighbor sampling, the code works.
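For reference, here is the full-graph training step that does work for me (a minimal sketch reusing the model, graph, and loss defined above):

    # No sampler, so nothing prefetches the grad-requiring tensor.
    feat = g.ndata['feat']                    # still requires grad
    labels = torch.squeeze(g.ndata['label'])
    logits = model(g, feat)
    loss = criterion(logits, labels.long())
    opt.zero_grad()
    loss.backward()                           # backward_hook fires here
    opt.step()
    print(feat.grad.shape)                    # torch.Size([4, 8])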
Is there a solution to this problem? I hope my question is clear.
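One workaround I have been experimenting with, in case it helps frame an answer (I'm not sure it is the intended approach): store the features on the graph detached so prefetching succeeds, then call requires_grad_() on each batch's prefetched inputs and read input_features.grad after backward. I call the conv layers once per block here, since I believe the mini-batch path needs one block per layer:

    # Detached storage keeps prefetch's index_select happy.
    # (Safest to do this before constructing the DataLoader.)
    g.ndata['feat'] = g.ndata['feat'].detach()

    for input_nodes, output_nodes, blocks in dataloader:
        input_features = blocks[0].srcdata['feat']
        input_features.requires_grad_()       # track grad w.r.t. this batch's inputs
        node_labels = torch.squeeze(blocks[-1].dstdata['label'])
        h = F.relu(model.conv1(blocks[0], input_features))
        node_predictions = model.conv2(blocks[1], h)
        loss = criterion(node_predictions, node_labels.long())
        opt.zero_grad()
        loss.backward()
        opt.step()
        batch_input_grad = input_features.grad   # gradient for the sampled inputs

This should avoid the prefetch error, since the stored tensor no longer requires grad, but I'd still like to know if there is a supported way to do this.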