Here is an example that applies Captum's `IntegratedGradients` to link prediction on a heterogeneous graph in DGL:
import dgl
import dgl.nn as dglnn
import torch
import torch.nn.functional as F
from captum.attr import IntegratedGradients
from functools import partial
# Load the MUTAG dataset and take the graph stored at index 0.
dataset = dgl.data.MUTAGDataset()
g = dataset[0]

# Attach synthetic 10-dimensional input features under the key 'x' to every
# node type. Replace these random features with your own.
for node_type in g.ntypes:
    num_nodes = g.num_nodes(node_type)
    g.nodes[node_type].data['x'] = torch.randn(num_nodes, 10)
# Your model...
class Model(torch.nn.Module):
    """Two-layer heterogeneous GraphSAGE encoder with a dot-product edge scorer.

    Args:
        etypes: edge types used as keys for the per-relation ``SAGEConv``
            modules (this script passes ``g.canonical_etypes``).
        in_feats: dimensionality of the input node features (default 10,
            matching the synthetic features built for this example).
        hidden_feats: output dimensionality of both layers (default 10).
    """

    def __init__(self, etypes, in_feats=10, hidden_feats=10):
        super().__init__()
        # One mean-aggregating SAGEConv per edge type in each layer.
        self.conv1 = dglnn.HeteroGraphConv(
            {etype: dglnn.SAGEConv(in_feats, hidden_feats, 'mean') for etype in etypes})
        self.conv2 = dglnn.HeteroGraphConv(
            {etype: dglnn.SAGEConv(hidden_feats, hidden_feats, 'mean') for etype in etypes})

    def forward(self, pair_graph, blocks, x):
        """Score the edges of ``pair_graph`` from a dict of input node features.

        Args:
            pair_graph: graph whose edges are scored; its nodes are expected to
                line up with the output nodes of ``blocks[1]``.
            blocks: the two message-flow graphs consumed by ``conv1`` and
                ``conv2`` respectively.
            x: dict mapping node type -> input feature tensor for ``blocks[0]``.

        Returns:
            dict mapping canonical edge type -> edge score tensor (always a
            dict, even when ``pair_graph`` has a single edge type). Edge types
            whose source or destination node type has no entry in the encoder
            output are omitted, since they cannot be scored.
        """
        h = self.conv2(blocks[1], self.conv1(blocks[0], x))
        with pair_graph.local_scope():
            # Assign per node type so the same code works whether the graph
            # has one node type or many.
            for ntype, feat in h.items():
                pair_graph.nodes[ntype].data['h'] = feat
            scores = {}
            for etype in pair_graph.canonical_etypes:
                src_type, _, dst_type = etype
                # Node types absent from ``h`` (presumably those that receive
                # no messages in HeteroGraphConv) have no representation, so
                # a dot-product score cannot be computed for this edge type.
                if src_type not in h or dst_type not in h:
                    continue
                pair_graph.apply_edges(
                    dgl.function.u_dot_v('h', 'h', 'score'), etype=etype)
                # Read per edge type: ``edata['score']`` would be a bare tensor
                # instead of a dict on a single-edge-type graph.
                scores[etype] = pair_graph.edges[etype].data['score']
            return scores
# The model to explain (it is not trained in this snippet -- load your own
# trained weights here).
m = Model(g.canonical_etypes)

# Minibatch sampling: a two-layer neighbor sampler with fanout 2 per layer,
# wrapped for edge prediction with 2 uniformly sampled negatives per edge.
neighbor_sampler = dgl.dataloading.NeighborSampler([2, 2])
negative_sampler = dgl.dataloading.negative_sampler.Uniform(2)
sampler = dgl.dataloading.as_edge_prediction_sampler(
    neighbor_sampler, negative_sampler=negative_sampler)

# Seed edges: every edge of the first canonical edge type.
explained_etype = g.canonical_etypes[0]
eid = {explained_etype: torch.arange(g.num_edges(explained_etype))}

# batch_size=1 so each iteration yields a single edge to explain.
dl = dgl.dataloading.DataLoader(g, eid, sampler, batch_size=1)
def forward_for_one_node_type(x, node_type, pair_graph, blocks, x_dict, output_edge_type):
    """Run the model with one node type's features replaced by ``x``.

    Takes the tensor being attributed as its first argument and returns a
    single tensor, which is the shape of callable the attribution expects.

    Args:
        x: replacement features for ``node_type``.
        node_type: the node type whose features are substituted.
        pair_graph: graph whose edges are scored by the model.
        blocks: message-flow graphs passed through to the model.
        x_dict: features for all node types; not modified.
        output_edge_type: key selecting which edge type's scores to return.

    Returns:
        The model's score tensor for ``output_edge_type``.
    """
    # A fresh dict, so the caller's x_dict is left untouched.
    features = {**x_dict, node_type: x}
    all_scores = m(pair_graph, blocks, features)
    return all_scores[output_edge_type]
# Explain one input node type and one output edge type at a time; both are
# fixed for the whole loop, so choose them once.
ntype = g.ntypes[0]
output_etype = g.canonical_etypes[0]
for input_nodes, pair_graph, neg_pair_graph, blocks in dl:
    # Input features of the first block's source nodes, keyed by node type.
    # ``srcdata`` is DGL's documented accessor for a block's input side.
    input_dict = blocks[0].srcdata['x']
    x_ntype = input_dict[ntype]
    ig = IntegratedGradients(partial(
        forward_for_one_node_type,
        node_type=ntype, pair_graph=pair_graph, blocks=blocks,
        x_dict=input_dict, output_edge_type=output_etype))
    # NOTE: internal_batch_size=1 (smaller than x_ntype's row count) makes
    # Captum evaluate one interpolation step per forward call. Each call then
    # receives a tensor with exactly x_ntype's rows, matching the graph's node
    # count, instead of n_steps copies stacked along dim 0.
    print(ig.attribute(x_ntype, target=0, internal_batch_size=1, n_steps=50))
    break  # only the first sampled positive edge is explained
Please let me know if you have more questions.