I want to use Captum to explain a DGL model. There was a similar discussion before (captum_dgl_demo), but I have a few questions:

  • In a node classification task, we use Captum's IG (Integrated Gradients) algorithm to interpret the edge features. What does the final attribution result mean, given that we never selected a specific node to explain?

  • In a node classification task, if I want to explain one specific node, what should I do?

Let’s consider the example below.

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from captum.attr import IntegratedGradients
from dgl.nn.pytorch import GraphConv
from functools import partial

class GCN(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, out_feats)
        self.conv2 = GraphConv(out_feats, out_feats)

    def forward(self, nfeats, g, nid=None):
        h = self.conv1(g, nfeats)
        h = F.relu(h)
        h = self.conv2(g, h)
        if nid is not None:
            # Keep only the output row of the node being explained
            return h[nid:nid+1]
        return h

g = dgl.graph(([1, 2, 3, 4, 4], [0, 1, 2, 3, 4]))
feat_size = 5
nfeats = torch.randn(g.num_nodes(), feat_size)
num_cls = 10
model = GCN(feat_size, num_cls)
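# Bind the graph so that Captum only varies the node features
ig = IntegratedGradients(partial(model.forward, g=g))
# internal_batch_size=g.num_nodes() makes every forward call receive exactly one
# full set of node features (one integration step), which the graph convolution needs
result = ig.attribute(nfeats, target=0, internal_batch_size=g.num_nodes(), n_steps=50)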

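result is a tensor with the same shape as nfeats. With target=0 and no nid, Captum differentiates the sum of the class-0 outputs of all nodes, so each entry scores how much that input feature contributes to the total class-0 score across all nodes, not to any single node's prediction. The same holds for edge features: their attribution has the shape of the edge feature tensor and explains the same aggregate score.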

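If you want to attribute importance for a single node, pass nid to forward, e.g. IntegratedGradients(partial(model.forward, g=g, nid=0)). The model then returns only that node's output, so the attributions explain that node's prediction alone.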

