Gradient wrt Input using DGLGraph

Hello!

I’m new to DGL and am trying to use the AGNN implemented here and here.

Specifically, after training the model, I would like to 1) run a forward pass with a mask over the adjacency matrix and corresponding features, 2) compute the cross entropy loss, and 3) perform a backward pass with respect to the predicted class of a specific node.

However, IIUC, the adjacency matrix is wrapped as a DGL object (self.g), and it is very unclear to me how to pull out these gradients.

Steps 1-3 seem doable. However, I would like to get the gradient wrt the masked adjacency matrix that was also fed through the network (not just the features!), and I have no idea how to do this.

In PyTorch, I would set adj.requires_grad = True, and after backpropping, the gradients are there. Is it possible to do something similar with DGL?
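For reference, the dense PyTorch version of what I mean might look like the minimal sketch below. The single linear layer, the sizes, and `node_id` are placeholders I made up for illustration; this is not the AGNN model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

num_nodes, in_feats, num_classes = 4, 3, 2

# Dense (masked) adjacency matrix treated as a differentiable input
adj = torch.tensor([[0., 1., 0., 0.],
                    [1., 0., 1., 1.],
                    [0., 1., 0., 0.],
                    [0., 1., 0., 0.]], requires_grad=True)
feats = torch.randn(num_nodes, in_feats)
linear = nn.Linear(in_feats, num_classes)  # stand-in for a trained model

# One round of neighbourhood aggregation followed by a linear classifier
logits = linear(adj @ feats)

# Cross entropy for a single node, using its own predicted class as the target
node_id = 1
pred = logits[node_id].argmax().unsqueeze(0)
loss = F.cross_entropy(logits[node_id].unsqueeze(0), pred)
loss.backward()

# Saliency map over the adjacency matrix, same shape as adj
saliency = adj.grad
```

With only one aggregation step, just row `node_id` of `saliency` can be non-zero, since the prediction for that node depends only on its own row of `adj`.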

tl;dr: Basically, I’d like a saliency-map equivalent over the input adjacency matrix, but I don’t know how to interact with the DGLGraph object to pull this out.

I appreciate any help! Sorry if this is a super noob question!

I saw another post, "How to compute the gradient of edge with respect to some loss", but truthfully, I still wasn’t sure how this would work in my context.

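How about something like the following? The idea is to pass the edge weights in as a separate tensor with requires_grad=True and use them during message passing, so that after the backward pass their gradient is available in edge_weights.grad: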

import dgl
import dgl.function as fn
import numpy as np
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self, in_feats, hidden_feats, out_feats=1):
        super(ToyModel, self).__init__()
        # Small MLP applied to every node's features before message passing
        self.f = nn.Sequential(
            nn.Linear(in_feats, hidden_feats),
            nn.ReLU(),
            nn.Linear(hidden_feats, out_feats))

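    def forward(self, g, feats, edge_weights):
        feats = self.f(feats)
        # Work on a local copy so the caller's graph data is left untouched
        g = g.local_var()
        g.ndata['h'] = feats
        g.edata['e'] = edge_weights
        # Each edge sends its source node's feature scaled by the edge weight;
        # each node then sums its incoming messages
        g.update_all(fn.u_mul_e('h', 'e', 'm'), fn.sum('m', 'h'))
        return g.ndata.pop('h')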

g = dgl.DGLGraph((np.array([0, 1, 1]), np.array([1, 2, 2])))
node_feats = torch.randn(g.number_of_nodes(), 2)
edge_weights = torch.ones((g.number_of_edges(), 1), requires_grad=True)
model = ToyModel(node_feats.shape[1], node_feats.shape[1], 1)
logits = model(g, node_feats, edge_weights)
labels = torch.ones(g.number_of_nodes(), 1)
criterion = nn.BCEWithLogitsLoss()
loss = criterion(logits, labels)
loss.backward()
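# One gradient entry per edge, in the same order as the edges of g
print(edge_weights.grad)
# Example output (values vary with the random initialization):
# tensor([[0.0137],
#         [0.0247],
#         [0.0247]])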
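
To get closer to a saliency map for a single node, one option is to compute the loss for that node only and then scatter the per-edge gradients into a dense matrix. Below is a rough sketch in plain PyTorch so it runs without DGL. The `index_add` call only emulates the `u_mul_e` + `sum` message passing above, and `node_id`, `target`, and `adj_grad` are names made up for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Edge list of the toy graph: edge i goes from src[i] to dst[i]
src = torch.tensor([0, 1, 1])
dst = torch.tensor([1, 2, 2])
num_nodes = 3

feats = torch.randn(num_nodes, 1)
edge_weights = torch.ones(src.shape[0], 1, requires_grad=True)

# Emulate u_mul_e + sum: every edge sends h[src] * e,
# and every destination node sums its incoming messages
messages = feats[src] * edge_weights
logits = torch.zeros(num_nodes, 1).index_add(0, dst, messages)

# Loss for one node only, so the gradient reflects that node's prediction
node_id = 2
target = torch.ones(1)
loss = F.binary_cross_entropy_with_logits(logits[node_id], target)
loss.backward()

# Scatter per-edge gradients into a dense (num_nodes x num_nodes) matrix;
# accumulate=True sums the contributions of duplicate edges such as 1 -> 2
adj_grad = torch.zeros(num_nodes, num_nodes)
adj_grad.index_put_((src, dst), edge_weights.grad.squeeze(1), accumulate=True)
```

Edges that cannot reach `node_id` within the number of message passing steps get a zero gradient, and entries of `adj_grad` with no edge stay zero. Unlike the dense `adj.requires_grad` approach, this gives no gradient for edges that are absent from the graph.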