Explainability - self attention

Is it possible to add a “self-attention” layer (or use self-attention pooling) that only takes into account the updated node features (not neighbourhoods), in order to rank the most useful nodes in the model?

I have been trying to work out how to do it but have not found a solution. This could be used to increase the explainability of the models.

In the case of graph property prediction, you may use GlobalAttentionPooling to compute a graph representation from the representations of its nodes. You can treat the “weights” output by the softmax function as the significance of the nodes.
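To make the idea concrete, here is a minimal sketch of what global attention pooling computes, written in plain PyTorch so it runs without DGL. It assumes a hypothetical `graph_ids` tensor that maps each node to its graph (a stand-in for a DGL batched graph) and a `gate_nn` that scores each node, analogous to the `gate_nn` argument of DGL's `GlobalAttentionPooling`. The per-graph softmax output `weights` is what can be read as node significance.

```python
import torch
import torch.nn as nn


def attention_pool(feat, graph_ids, num_graphs, gate_nn):
    """Sketch of global attention pooling (not DGL's implementation).

    feat:      (N, D) node features for all nodes in the batch
    graph_ids: (N,) long tensor giving the graph index of each node (hypothetical input)
    gate_nn:   module mapping (N, D) -> (N, 1) unnormalized scores
    Returns (readout, weights): readout is (num_graphs, D), weights is (N, 1).
    """
    scores = gate_nn(feat)  # (N, 1) raw attention scores
    weights = torch.empty_like(scores)
    for gid in range(num_graphs):
        mask = graph_ids == gid
        # Softmax over the nodes of one graph only, so each graph's weights sum to 1.
        weights[mask] = torch.softmax(scores[mask], dim=0)
    # Attention-weighted sum of node features, one row per graph.
    readout = torch.zeros(num_graphs, feat.shape[1], dtype=feat.dtype)
    readout = readout.index_add(0, graph_ids, weights * feat)
    return readout, weights


torch.manual_seed(0)
feat = torch.randn(5, 4)                    # 5 nodes, 4 features each
graph_ids = torch.tensor([0, 0, 0, 1, 1])   # nodes 0-2 in graph 0, nodes 3-4 in graph 1
readout, weights = attention_pool(feat, graph_ids, 2, nn.Linear(4, 1))
print(readout.shape)  # torch.Size([2, 4])
print(weights.shape)  # torch.Size([5, 1])
```

The graph representation has one row per graph, while the attention weights have one row per node; the latter are the values to rank.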

Many thanks, Mufeili!

I am trying to implement it, but I am running into some issues.

My batch contains 20 graphs, the initial node features have 8 dimensions, and the hidden dimension is 4.

If I take my regular classifier and replace my previous pooling method with global attention pooling (see below), I get the error:

```
TypeError: expected Tensor as element 0 in argument 0, but got GlobalAttentionPooling
```

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GlobalAttentionPooling, GraphConv


class Classifier2(nn.Module):
    def __init__(self, in_dim, hidden_dim, added_dim, n_classes):
        super(Classifier2, self).__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        self.dense1 = nn.Linear(hidden_dim + added_dim, hidden_dim + added_dim)
        self.classify = nn.Linear(hidden_dim + added_dim, n_classes)

    def forward(self, g, extra_feats):
        # Use the stored initial node features 'h_n' as input.
        h = g.ndata['h_n'].float()
        # Perform graph convolution and activation function.
        h = F.relu(self.conv1(g, h))
        h = F.relu(self.conv2(g, h))
        g.ndata['h'] = h
        # Intended: compute the graph representation with attention pooling.
        # This line constructs a new GlobalAttentionPooling module (with g as
        # its gate_nn) instead of pooling h, so torch.cat receives a module.
        hg = GlobalAttentionPooling(g)
        d1 = self.dense1(torch.cat((hg, extra_feats), dim=1))
        return self.classify(d1)
```

I tried returning hg directly, as above, and its size is 20x1 (batch size x 1). I wonder how to interpret that.

I also checked what hg = GlobalAttentionPooling.forward(g) returns: it has dimensions 13x1, which I don’t understand.

Could you help me a bit here, please? As you mentioned before, I’d like to get the weights to use as the significance of the nodes.

I’m also curious: can something similar be done with set2set pooling?

You did not use GlobalAttentionPooling correctly. You need to create a GlobalAttentionPooling instance in __init__ and use that instance in the forward function of Classifier2. See an example here and the doc here.
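As an illustration of that pattern, here is a DGL-free sketch. With DGL itself, the two key lines would be `self.pool = GlobalAttentionPooling(nn.Linear(hidden_dim, 1))` in `__init__` and `hg = self.pool(g, h)` in `forward`. To keep the example runnable on its own, it assumes a hypothetical `AttnPool` module standing in for `GlobalAttentionPooling`, a hypothetical `graph_ids` tensor for batching, and plain `nn.Linear` layers in place of `GraphConv` (so there is no message passing here); the `added_dim=3` value is also made up.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttnPool(nn.Module):
    """Hypothetical stand-in for DGL's GlobalAttentionPooling (no DGL needed)."""

    def __init__(self, gate_nn):
        super().__init__()
        self.gate_nn = gate_nn  # maps (N, D) -> (N, 1)

    def forward(self, feat, graph_ids, num_graphs, get_attention=False):
        scores = self.gate_nn(feat)
        gate = torch.empty_like(scores)
        for gid in range(num_graphs):
            mask = graph_ids == gid
            gate[mask] = torch.softmax(scores[mask], dim=0)  # per-graph softmax
        readout = torch.zeros(num_graphs, feat.shape[1], dtype=feat.dtype)
        readout = readout.index_add(0, graph_ids, gate * feat)
        return (readout, gate) if get_attention else readout


class Classifier2(nn.Module):
    def __init__(self, in_dim, hidden_dim, added_dim, n_classes):
        super().__init__()
        # Stand-ins for GraphConv so the sketch runs without DGL (no message passing).
        self.conv1 = nn.Linear(in_dim, hidden_dim)
        self.conv2 = nn.Linear(hidden_dim, hidden_dim)
        # Create the pooling module once, with a gate that outputs one score per node.
        self.pool = AttnPool(nn.Linear(hidden_dim, 1))
        self.dense1 = nn.Linear(hidden_dim + added_dim, hidden_dim + added_dim)
        self.classify = nn.Linear(hidden_dim + added_dim, n_classes)

    def forward(self, feat, graph_ids, num_graphs, extra_feats):
        h = F.relu(self.conv1(feat))
        h = F.relu(self.conv2(h))
        # Call the instance created in __init__; it returns a tensor, not a module.
        hg = self.pool(h, graph_ids, num_graphs)  # (num_graphs, hidden_dim)
        d1 = self.dense1(torch.cat((hg, extra_feats), dim=1))
        return self.classify(d1)


torch.manual_seed(0)
model = Classifier2(in_dim=8, hidden_dim=4, added_dim=3, n_classes=2)
feat = torch.randn(7, 8)                         # 7 nodes in total
graph_ids = torch.tensor([0, 0, 0, 1, 1, 2, 2])  # 3 graphs in the batch
extra_feats = torch.randn(3, 3)                  # one row of extra features per graph
logits = model(feat, graph_ids, 3, extra_feats)
print(logits.shape)  # torch.Size([3, 2])
```

The pooled output has one row per graph, which is what torch.cat needs to concatenate it with the per-graph extra features.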

You can do something similar with the set2set pooling as well.
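In Set2Set the attention weights (called `alpha` here) are recomputed at every iteration. The DGL-free sketch below follows the usual Set2Set readout (an LSTM produces a query per graph, nodes are scored against it, and the attention-weighted sum is fed back) and simply keeps every iteration's `alpha`. The class name `Set2SetSketch` and the `graph_ids` input are hypothetical; DGL's own `Set2Set(input_dim, n_iters, n_layers)` computes the weights internally, and this sketch returns them explicitly so they can be inspected.

```python
import torch
import torch.nn as nn


class Set2SetSketch(nn.Module):
    """Hypothetical DGL-free Set2Set that also returns every iteration's alpha."""

    def __init__(self, input_dim, n_iters, n_layers=1):
        super().__init__()
        self.input_dim, self.n_iters, self.n_layers = input_dim, n_iters, n_layers
        # The LSTM reads q* (size 2 * input_dim) and emits a query of size input_dim.
        self.lstm = nn.LSTM(2 * input_dim, input_dim, n_layers)

    def forward(self, feat, graph_ids, num_graphs):
        shape = (self.n_layers, num_graphs, self.input_dim)
        state = (feat.new_zeros(shape), feat.new_zeros(shape))
        q_star = feat.new_zeros((num_graphs, 2 * self.input_dim))
        alphas = []
        for _ in range(self.n_iters):
            q, state = self.lstm(q_star.unsqueeze(0), state)
            q = q.squeeze(0)  # (num_graphs, input_dim)
            # Score each node against the query of the graph it belongs to.
            e = (feat * q[graph_ids]).sum(dim=-1, keepdim=True)  # (N, 1)
            alpha = torch.empty_like(e)
            for gid in range(num_graphs):
                mask = graph_ids == gid
                alpha[mask] = torch.softmax(e[mask], dim=0)  # per-graph softmax
            alphas.append(alpha)
            readout = feat.new_zeros((num_graphs, self.input_dim))
            readout = readout.index_add(0, graph_ids, alpha * feat)
            q_star = torch.cat([q, readout], dim=-1)
        return q_star, alphas


torch.manual_seed(0)
pool = Set2SetSketch(input_dim=4, n_iters=3)
feat = torch.randn(5, 4)
graph_ids = torch.tensor([0, 0, 0, 1, 1])
q_star, alphas = pool(feat, graph_ids, num_graphs=2)
print(q_star.shape, len(alphas))  # torch.Size([2, 8]) 3
```

Each iteration yields its own set of node weights, so there is no single importance score unless you choose one (for example the last iteration, or an average across iterations).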

Many thanks again for the examples and your kind help! I finally got it working, but the output of the pooling is still batch size x feature size (which, from the docs, is what it should be).

Looking at the source code of GlobalAttentionPooling, should I modify it to return the feat * gate result? Is that the softmax weight you were referring to in your first reply?

In the case of set2set, is feat * alpha what I should try to get?

Many thanks in advance!

For GlobalAttentionPooling, you can use gate (the softmax output itself, not feat * gate) as node importance. For Set2Set, you can use alpha as node importance. Note that in the latter case there are multiple alpha tensors, one per iteration, and you may want to examine all of them.
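As an illustration of using gate for node importance: recent DGL versions let `GlobalAttentionPooling.forward` return it directly via `get_attention=True` (for example `hg, gate = self.pool(g, h, get_attention=True)`), so patching the source should not be necessary; check that your installed version supports this argument. The sketch below assumes such a `gate` tensor of shape (N, 1) covering every node in the batch, plus per-graph node counts such as those from `g.batch_num_nodes()`, and ranks each graph's nodes by weight. The function name `top_nodes_per_graph` is hypothetical and the `gate` values are made up for the example.

```python
import torch


def top_nodes_per_graph(gate, num_nodes_per_graph, k):
    """Return, for each graph, the local indices of its k highest-weighted nodes.

    gate:                (N, 1) softmax weights for all nodes in the batch
    num_nodes_per_graph: list of ints, e.g. g.batch_num_nodes().tolist() in DGL
    """
    ranked = []
    # Split the batch-level weights back into one chunk per graph.
    for weights in torch.split(gate.squeeze(-1), num_nodes_per_graph):
        top = torch.topk(weights, min(k, weights.numel()))  # sorted, largest first
        ranked.append(top.indices.tolist())
    return ranked


# Made-up weights: graph 0 has 3 nodes, graph 1 has 2 (each graph sums to 1).
gate = torch.tensor([[0.1], [0.6], [0.3], [0.8], [0.2]])
print(top_nodes_per_graph(gate, [3, 2], k=2))  # [[1, 2], [0, 1]]
```

For Set2Set, the same function can be applied to each of the alpha tensors in turn, one per iteration.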

Yay, got it! Many thanks! Very much appreciated.

I think you can try that.