How to get input feature gradient with dataloader

I want to get the gradient of the input features.
This is my code:

import dgl
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F

class SAGE(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super().__init__()
        self.conv1 = dglnn.SAGEConv(
            in_feats=in_feats, out_feats=hid_feats, aggregator_type='mean')
        self.conv2 = dglnn.SAGEConv(
            in_feats=hid_feats, out_feats=out_feats, aggregator_type='mean')

    def forward(self, blocks, inputs):
        # inputs are features of the source nodes of the first block
        h = self.conv1(blocks[0], inputs)
        h = F.relu(h)
        h = self.conv2(blocks[1], h)
        return h

def backward_hook(module, grad_input, grad_output):
    # cache the gradient seen at conv1's fc_self layer during backward()
    global input_grad
    input_grad = grad_input[1].data
 
g = dgl.graph(([0, 2, 0, 1], [1, 0, 3, 2]), idtype=torch.int32)
g.ndata['feat'] = torch.arange(32).reshape((4, 8)).float()
g.ndata['feat'].requires_grad_()
g = dgl.add_reverse_edges(g, copy_ndata=True)
g.ndata['label'] = torch.ones((g.num_nodes(), 1)).long()

model = SAGE(8, 4, 2)
opt = torch.optim.Adam(model.parameters())
criterion = torch.nn.CrossEntropyLoss()
model.conv1.fc_self.register_backward_hook(backward_hook)

model.train()

sampler = dgl.dataloading.NeighborSampler(
    [1, 2],
    replace=True,
    prefetch_node_feats=['feat'],
    prefetch_labels=['label']
)

train_nids = torch.arange(g.num_nodes()).int()

dataloader = dgl.dataloading.DataLoader(
    g,
    train_nids,
    sampler,
    batch_size=2,
    shuffle=False,
    drop_last=False,
    num_workers=0
)

for input_nodes, output_nodes, blocks in dataloader:
    input_features = blocks[0].srcdata['feat']
    node_labels = blocks[-1].dstdata['label']
    node_labels = torch.squeeze(node_labels)
    node_predictions = model(blocks, input_features)
    loss = criterion(node_predictions, node_labels.long())

    opt.zero_grad()
    loss.backward()
    opt.step()

 

This code raises the following error:

Traceback (most recent call last):
  File "dataloader.py", line 334, in _prefetch
    feats = recursive_apply(batch, _prefetch_for, dataloader)
  File "lib\site-packages\dgl\utils\internal.py", line 1132, in recursive_apply
    return [recursive_apply(v, fn, *args, **kwargs) for v in data]
  File "lib\site-packages\dgl\utils\internal.py", line 1132, in <listcomp>
    return [recursive_apply(v, fn, *args, **kwargs) for v in data]
  File "lib\site-packages\dgl\utils\internal.py", line 1132, in recursive_apply
    return [recursive_apply(v, fn, *args, **kwargs) for v in data]
  File "lib\site-packages\dgl\utils\internal.py", line 1132, in <listcomp>
    return [recursive_apply(v, fn, *args, **kwargs) for v in data]
  File "lib\site-packages\dgl\utils\internal.py", line 1134, in recursive_apply
    return fn(data, *args, **kwargs)
  File "lib\site-packages\dgl\dataloading\dataloader.py", line 286, in _prefetch_for
    return _prefetch_for_subgraph(item, dataloader)
  File "lib\site-packages\dgl\dataloading\dataloader.py", line 275, in _prefetch_for_subgraph
    _prefetch_update_feats(
  File "lib\site-packages\dgl\dataloading\dataloader.py", line 260, in _prefetch_update_feats
    feats[tid, key] = get_storage_func(parent_key, type).fetch(
  File "lib\site-packages\dgl\frame.py", line 530, in fetch
    return super().fetch(indices, device, pin_memory=pin_memory, **kwargs)
  File "lib\site-packages\dgl\storages\pytorch_tensor.py", line 48, in fetch
    return _fetch_cpu(
  File "lib\site-packages\dgl\storages\pytorch_tensor.py", line 17, in _fetch_cpu
    torch.index_select(tensor, 0, indices, out=result)
RuntimeError: index_select(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

It seems that the dataloader doesn't support tensors that require grad.
If I don't use neighbor sampling, this code works.
Is there a solution to this problem?
I hope my question is clear.

Since torch.index_select() with out= doesn't support automatic differentiation, you could try removing prefetch_node_feats=['feat'] from the neighbor sampler and instead calling input_features = torch.index_select(g.ndata['feat'], 0, input_nodes) inside the dataloader loop, as sketched below.
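Roughly like this (an untested sketch reusing the variable names from your code above; the .long() cast on input_nodes is just a precaution, since the graph uses int32 node ids):

sampler = dgl.dataloading.NeighborSampler(
    [1, 2],
    replace=True,
    prefetch_labels=['label']  # note: no prefetch_node_feats here
)

dataloader = dgl.dataloading.DataLoader(
    g, train_nids, sampler,
    batch_size=2, shuffle=False, drop_last=False, num_workers=0
)

for input_nodes, output_nodes, blocks in dataloader:
    # gather the batch features by hand; this index_select runs under
    # autograd, so the computation graph reaches back to g.ndata['feat']
    input_features = torch.index_select(g.ndata['feat'], 0, input_nodes.long())
    node_labels = torch.squeeze(blocks[-1].dstdata['label'])
    node_predictions = model(blocks, input_features)
    loss = criterion(node_predictions, node_labels.long())

    opt.zero_grad()
    loss.backward()
    opt.step()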

Could you try this workaround?

I disabled prefetch_node_feats=['feat'] and it works!
Thanks a lot!
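One more note in case it helps later readers: because the manual torch.index_select keeps the autograd connection to the leaf tensor g.ndata['feat'], the gradient can also be read directly from it after backward(), without any hook. For example, at the end of each loop iteration:

    loss.backward()

    # the leaf tensor now holds the gradient of the loss w.r.t. the full
    # feature matrix; rows outside this batch's receptive field stay zero
    feat_grad = g.ndata['feat'].grad
    print(feat_grad)

    # opt.zero_grad() only clears the model parameters, so reset the
    # feature gradient by hand to keep it from accumulating across batches
    g.ndata['feat'].grad = None
    opt.step()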
