Explainability for Heterogeneous Graphs

Hi there!

I saw the great post on post-hoc explainability with DGL. I was wondering whether this would also be possible for heterogeneous graphs with multimodal input.

In particular, how could multiple input feature tensors be passed into the forward function? And since we are using both a positive and a negative graph, how could those be handled?

Thanks a lot for any help in advance!

I think IntegratedGradients requires the input and output to each be a single tensor, so you could wrap your heterogeneous GNN in another function and interpret one node type at a time?

Here is an example using IntegratedGradients with heterogeneous graph link prediction:

import dgl
import dgl.nn as dglnn
import torch
import torch.nn.functional as F
from captum.attr import IntegratedGradients
from functools import partial

dataset = ...  # dataset elided in the original post; load your own heterogeneous DGL dataset here

g = dataset[0]

# Some synthetic node features - replace it with your own.
for ntype in g.ntypes:
    g.nodes[ntype].data['x'] = torch.randn(g.num_nodes(ntype), 10)

# Your model...
class Model(torch.nn.Module):
    def __init__(self, etypes):
        super().__init__()  # required before assigning submodules to an nn.Module
        self.conv1 = dglnn.HeteroGraphConv({etype: dglnn.SAGEConv(10, 10, 'mean') for etype in etypes})
        self.conv2 = dglnn.HeteroGraphConv({etype: dglnn.SAGEConv(10, 10, 'mean') for etype in etypes})
    # Say that your forward function takes in a dictionary of node features and returns a dictionary
    # of edge scores.
    def forward(self, pair_graph, blocks, x):
        h = self.conv2(blocks[1], self.conv1(blocks[0], x))
        with pair_graph.local_scope():
            pair_graph.ndata['h'] = h
            for etype in pair_graph.canonical_etypes:
                pair_graph.apply_edges(dgl.function.u_dot_v('h', 'h', 'score'), etype=etype)
            return pair_graph.edata['score']
m = Model(g.canonical_etypes)

# Minibatch sampling stuff...
sampler = dgl.dataloading.as_edge_prediction_sampler(
    dgl.dataloading.NeighborSampler([2, 2]), negative_sampler=dgl.dataloading.negative_sampler.Uniform(2))
eid = {g.canonical_etypes[0]: torch.arange(g.num_edges(g.canonical_etypes[0]))}
# Let's iterate over and explain one edge at a time.
dl = dgl.dataloading.DataLoader(g, eid, sampler, batch_size=1)

# Define a function that takes in a single tensor as the first argument and also returns a
# single tensor.
def forward_for_one_node_type(x, node_type, pair_graph, blocks, x_dict, output_edge_type):
    x_dict = x_dict.copy()
    x_dict[node_type] = x
    return m(pair_graph, blocks, x_dict)[output_edge_type]

for input_nodes, pair_graph, neg_pair_graph, blocks in dl:
    ntype = g.ntypes[0]                     # explain one input node type at a time.
    output_etype = g.canonical_etypes[0]    # explain one output edge type at a time.
    input_dict = blocks[0].ndata['x']
    x_ntype = input_dict[ntype]
    ig = IntegratedGradients(partial(forward_for_one_node_type,
                                     node_type=ntype, pair_graph=pair_graph, blocks=blocks, x_dict=input_dict,
                                     output_edge_type=output_etype))
    print(ig.attribute(x_ntype, target=0, internal_batch_size=1, n_steps=50))

Please let me know if you have more questions.


@BarclayII, you are a true genius! I just tried it out and adapted it to my case, where I have multiple features for each node type, and I got it working there, too!

That's just amazing, thank you so much :pray:
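For several feature tensors on the same node type, Captum's attribution methods also accept a tuple of input tensors, which are passed to the forward function as separate positional arguments. Captum assumes that dimension 0 of every tensor in the tuple indexes the same examples, so this fits several feature tensors of one node type (same number of rows) better than features of different node types. A minimal sketch, assuming a hypothetical model that takes a dict of named feature tensors; the names `forward_for_feature_tuple`, `feature_names` and `model_fn` are illustrative, not from the thread:

```python
import torch


def forward_for_feature_tuple(*features, feature_names, x_dict, model_fn):
    """Map positional feature tensors back into the dict a model expects.

    All tensors are assumed to describe the same nodes, so they share dim 0.
    """
    if len(features) != len(feature_names):
        raise ValueError(f"got {len(features)} tensors for {len(feature_names)} names")
    inputs = dict(x_dict)  # shallow copy, so the caller's dict is not modified
    inputs.update(zip(feature_names, features))  # overwrite only the explained features
    return model_fn(inputs)
```

With `functools.partial(forward_for_feature_tuple, feature_names=[...], x_dict=..., model_fn=...)` as the forward function, a call like `ig.attribute((x_a, x_b), target=0)` would then return one attribution tensor per input tensor. This is an untested sketch of the idea, not the code from the post above.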

Hi @BarclayII,

I actually do have one more question. How would it be possible to use other Captum methods in this case, such as GradientShap or DeepLift?
I just tried switching to another explainer method, and the code broke with the following errors:
For GradientShap:

Exception has occurred: DGLError
Expect number of features to match number of nodes (len(u)). Got 59900 and 5990 instead.
  File "/home/skrix/.conda/envs/env_explainability/lib/python3.8/site-packages/dgl/", line 4119, in _set_n_repr
    raise DGLError('Expect number of features to match number of nodes (len(u)).'

(By the way, 5990 is the number of nodes of the current node type, and the batch size in the dataloader is 100, so maybe that is where 59900 comes from?)

For DeepLift:

Exception has occurred: AttributeError      
'functools.partial' object has no attribute 'register_forward_pre_hook'
  File "/home/skrix/.conda/envs/env_explainability/lib/python3.8/site-packages/captum/attr/_core/", line 596, in _hook_main_model
    self.model.register_forward_pre_hook(pre_hook),  # type: ignore

I can see that IntegratedGradients has an internal_batch_size argument, which is probably why it works there. Can you think of a way to adapt the code so that it works for GradientShap, too?

For DeepLift, I think you will need to wrap the partial function above in a PyTorch module instead, since register_forward_pre_hook only exists on PyTorch nn.Module objects.
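A minimal sketch of such a module wrapper. It mirrors the `partial` arguments of `forward_for_one_node_type` from the IntegratedGradients example; the class name `SingleNodeTypeModule` is hypothetical. Storing the model as an attribute registers it as a submodule, and the wrapper inherits `register_forward_pre_hook` from `torch.nn.Module`, which is what the DeepLift traceback was missing. This has not been tested against Captum.

```python
import torch


class SingleNodeTypeModule(torch.nn.Module):
    """One-tensor-in, one-tensor-out view of a heterogeneous link-prediction model."""

    def __init__(self, model, node_type, pair_graph, blocks, x_dict, output_edge_type):
        super().__init__()
        self.model = model  # registered as a submodule, so module hooks can reach its layers
        self.node_type = node_type
        self.pair_graph = pair_graph  # fixed graph inputs, stored as plain attributes
        self.blocks = blocks
        self.x_dict = x_dict
        self.output_edge_type = output_edge_type

    def forward(self, x):
        x_dict = dict(self.x_dict)  # copy so the stored features are not overwritten
        x_dict[self.node_type] = x  # only this node type's features are explained
        return self.model(self.pair_graph, self.blocks, x_dict)[self.output_edge_type]
```

Inside the dataloader loop of the IntegratedGradients example, this would presumably be used as `DeepLift(SingleNodeTypeModule(m, ntype, pair_graph, blocks, input_dict, output_etype)).attribute(x_ntype, target=0)`.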
For GradientShap, could you try setting batch_size=1 and show me how you choose the baselines?
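On the GradientShap error itself: 59900 is exactly 10 × 5990, i.e. ten stacked copies of the node-feature tensor, not a factor of the dataloader batch size of 100. An assumption worth checking: GradientShap draws `n_samples` random samples (default 5) and, as far as I know, evaluates them in a single forward call with the copies stacked along dim 0, while `GradientShap.attribute` has no `internal_batch_size` argument. If that is the cause, one possible workaround is a forward function that splits the stacked tensor back into graph-sized chunks and runs the model once per chunk. The helper `chunked_forward` is hypothetical:

```python
import torch


def chunked_forward(x, num_nodes, forward_one):
    """Run forward_one on each num_nodes-row slice of x and concatenate the outputs.

    Assumes x stacks k complete copies of a (num_nodes, F) feature tensor along dim 0.
    """
    if x.shape[0] % num_nodes != 0:
        raise ValueError(f"expected a multiple of {num_nodes} rows, got {x.shape[0]}")
    chunks = torch.split(x, num_nodes, dim=0)  # k tensors of shape (num_nodes, F)
    # torch.split and torch.cat are differentiable, so gradients flow back to x.
    return torch.cat([forward_one(chunk) for chunk in chunks], dim=0)
```

In the example above, the forward function would then be something like `partial(chunked_forward, num_nodes=x_ntype.shape[0], forward_one=partial(forward_for_one_node_type, ...))`.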

Hi @BarclayII,

So for GradientShap:
I construct the baselines like this, given that the batch_size in the dataloader is set to 1:

baselines = th.zeros(x_ntype.shape, device=device)
>>> torch.Size([5990, 500])

This still throws the error mentioned above.
What I tried next was to generate subgraphs that contain only that one edge, so that their ndata only holds the features of its two nodes. This worked!

# Find all etype_1 edges and run the analysis for each edge

edges = test_graph.edges(etype='etype_1')
ntype_1_indices, ntype_2_indices = edges
# ntype_1_indices = ...  (further processing elided in the original post)
# ntype_2_indices = ...

if subgraph_flag:
    subgraph_collector = dict()
    for ntype_1_index, ntype_2_index in zip(ntype_1_indices, ntype_2_indices):
        # Keep only the two endpoint nodes, so ndata holds just their features.
        subgraph = dgl.node_subgraph(test_graph, {
            'ntype_1': [int(ntype_1_index)],
            'ntype_2': [int(ntype_2_index)],
        })
        subgraph_collector[(int(ntype_1_index), int(ntype_2_index))] = subgraph

>>> torch.Size([1, 500])

This now works, and GradientShap produces results for individual test edges.

Hi @BarclayII!

I was wondering whether it is possible to also run the IG analysis on the negative samples (neg_pair_graph), specifically for each edge individually. Analysing an entire graph with multiple edges has a problem: we use the node features as the inputs to be attributed, but a node can participate in several of the edges tested in the graph. Since we only get the importances per node in total, we cannot trace back which edge prediction a particular node feature was important for. Therefore, it is crucial to be able to analyse edges individually.
Do you have an idea how to analyse the negatively sampled edges individually?
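A toy sketch of the per-edge idea, using plain gradient × input rather than Integrated Gradients so that it only needs PyTorch. The scorer is a dot product of endpoint features, like the `u_dot_v` score in the model above; the function name and the toy scorer are illustrative assumptions. Backpropagating each edge score on its own yields one attribution map per edge, so a node shared by two edges gets a separate map for each. Adding the maps back up gives the same result as backpropagating the summed score, which is exactly the mixed-up total described in the question.

```python
import torch


def per_edge_grad_times_input(x, src, dst):
    """Return one gradient-x-input attribution map per edge.

    x: (num_nodes, num_features) node features; src/dst: (num_edges,) endpoint IDs.
    The toy edge score is the dot product of the two endpoint feature vectors.
    """
    x = x.detach().requires_grad_(True)
    scores = (x[src] * x[dst]).sum(dim=1)  # one scalar score per edge
    maps = []
    for e in range(scores.shape[0]):
        # Backpropagate a single edge score, so the map belongs to that edge only.
        (grad,) = torch.autograd.grad(scores[e], x, retain_graph=True)
        maps.append((grad * x).detach())
    return torch.stack(maps)  # (num_edges, num_nodes, num_features)
```

For example, with edges 0→1 and 0→2, node 0 receives two different maps, one per edge, instead of a single summed importance.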

You could subgraph your neg_pair_graph with relabel_nodes=False, or alternatively set the number of negative examples per positive to 1.

@BarclayII I have already set the number of negative examples per positive example to 1. The issue is that when I set the batch size in the dataloader to 1 (to analyse each sample individually), the number of nodes in the graph (which should be 2) does not match the number of features (which is higher, since it is the total number of nodes in the original graph we sample from). Is there a way to overcome this issue?

What advantage would relabel_nodes=False have here? Do you mean creating subgraphs with one edge each?