Thanks again @BarclayII! It works with the dummy data. I'm leaving the code here, with visualisation, for reference in case someone wants to do something similar in the future
(I will run some trials on the Mutagenicity dataset to check whether the explanations roughly agree with the known ground truth).
I have an extra question: I have a trained model of the form
import dgl
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv

# GlobalAttentionPooling_mod is a custom module (not shown here) that returns two values.
class Classifier_gen(nn.Module):
    def __init__(self, in_dim, hidden_dim_graph, hidden_dim1, n_classes, dropout, num_layers, pooling):
        super(Classifier_gen, self).__init__()
        if num_layers == 1:
            self.conv1 = GraphConv(in_dim, hidden_dim1)
        if num_layers == 2:
            self.conv1 = GraphConv(in_dim, hidden_dim_graph)
            self.conv2 = GraphConv(hidden_dim_graph, hidden_dim1)
        if pooling == 'att':
            pooling_gate_nn = nn.Linear(hidden_dim1, 1)
            self.pooling = GlobalAttentionPooling_mod(pooling_gate_nn)
        self.classify = nn.Sequential(nn.Linear(hidden_dim1, hidden_dim1), nn.Dropout(dropout))
        self.classify2 = nn.Sequential(nn.Linear(hidden_dim1, n_classes), nn.Dropout(dropout))
        self.out_act = nn.Sigmoid()

    def forward(self, g, num_layers, pooling):
        h = g.ndata['h_n'].float()
        h = F.relu(self.conv1(g, h))
        if num_layers == 2:  # compare with the int 2, as in __init__ (was the string '2')
            h = F.relu(self.conv2(g, h))
        g.ndata['h'] = h
        if pooling == "max":
            hg = dgl.max_nodes(g, 'h')
        elif pooling == "mean":
            hg = dgl.mean_nodes(g, 'h')
        elif pooling == "sum":
            hg = dgl.sum_nodes(g, 'h')
        elif pooling == 'att':
            # Calculate the graph representation by attention pooling over the node representations.
            [hg, g2] = self.pooling(g, h)
        g2 = hg
        a2 = self.classify(hg)
        a3 = self.classify2(a2)
        return self.out_act(a3), g2, hg
And I am trying to get the node feature relevance as we were doing before (my variable "model" is the trained classifier).
When I do
ig = IntegratedGradients(partial(model, G,num_layers=2,pooling='att'))
mask = ig.attribute(G, target=0, internal_batch_size=1, n_steps=50)
I get the error:
AssertionError: `inputs` must have type torch.Tensor but <class 'dgl.heterograph.DGLHeteroGraph'> found:
I have tried to do some wrapping as
def tmp(G):
    return model(G, G.ndata['h_n'].float(), num_layers=2, pooling='att')
ig = IntegratedGradients(tmp)
mask = ig.attribute(G, target=0, internal_batch_size=1, n_steps=50)
But I get the same error.
Do you have any idea whether it is possible to make this work? Since the forward function of the classifier takes several inputs, already reads the node features from the graph, and returns several things as well, I am worried it might not work. Should I redefine the classifier instead? Thanks!
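For reference, here is a minimal plain-Python sketch of why the `partial(...)` binding cannot receive a feature tensor. It assumes the attribution method calls the forward function with the input tensor as the first positional argument, which is how I understand Captum to behave. The functions `forward`, `forward_with_feats` and `forward_func` are hypothetical stand-ins that only mirror the parameter order of `Classifier_gen.forward`. No DGL or Captum is involved, and strings take the place of the graph and the features.

```python
from functools import partial


def forward(g, num_layers, pooling):
    """Stand-in with the same parameter order as Classifier_gen.forward (without self)."""
    return ("prediction", "g2", "hg")


# Same binding pattern as partial(model, G, num_layers=2, pooling='att'):
# the graph already fills the first positional slot.
bound = partial(forward, "graph", num_layers=2, pooling="att")

try:
    # The caller passes the input to attribute as the first positional argument,
    # which lands in the num_layers slot and clashes with the bound keyword.
    bound("node_features")
    partial_error = None
except TypeError as exc:
    partial_error = str(exc)


def forward_with_feats(g, feats, num_layers, pooling):
    """Hypothetical variant of forward that receives the node features explicitly."""
    return (f"prediction from {feats} on {g}", "g2", "hg")


def forward_func(feats, g, num_layers=2, pooling="att"):
    """Adapter: features first, graph as an extra argument, single return value."""
    prediction, _g2, _hg = forward_with_feats(g, feats, num_layers, pooling)
    return prediction


print(partial_error)
print(forward_func("node_features", "graph"))
```

With the real model, the analogous step would presumably be to let `forward` accept the node features as an argument instead of reading `g.ndata['h_n']` itself. The call would then be something like `ig.attribute(feats, additional_forward_args=(G,), target=0, ...)`, with the wrapper returning only the first element of the output tuple. I have not run this against the classifier above, so treat it as an assumption.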