Attention Weights for a Specific Node

The code has changed since this solution was posted. How can I get the attention weights for a specific node in the current version? I'd appreciate the help!

I added a new parameter (attn_all) to the GATConv module and assigned graph.edata['a'] to it. Is that the right direction?

The attention weights of all edges are computed and stored in each layer by this line. graph.edata['a'] should give you a tensor of shape (E, H, 1), where E is the number of edges and H is the number of heads.
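To narrow that edge-level tensor down to one node, you select the edges that end at that node. Below is a minimal plain-Python sketch, not DGL's own API. It assumes you have already pulled the edge endpoints out of the graph (for example with src, dst = graph.edges(), whose default order matches edge IDs) and converted the attention tensor with .tolist(), so attn[i][h][0] is the weight of edge i for head h. Recent DGL releases also let you call a GATConv layer with get_attention=True to get this tensor back alongside the output; check the docs for your version. The helper name node_attention and the toy data are hypothetical.

```python
from typing import List, Sequence, Tuple


def node_attention(
    src: Sequence[int],
    dst: Sequence[int],
    attn: Sequence[Sequence[Sequence[float]]],
    node: int,
) -> List[Tuple[int, List[float]]]:
    """Hypothetical helper: return (source node, per-head weights) for every
    edge that ends at `node`.

    Assumes src[i], dst[i] are the endpoints of edge i and attn[i][h][0] is the
    attention weight of edge i for head h, i.e. an (E, H, 1) layout as nested lists.
    """
    if not (len(src) == len(dst) == len(attn)):
        raise ValueError("src, dst and attn must have one entry per edge")
    result = []
    for eid, (u, v) in enumerate(zip(src, dst)):
        if v == node:
            # Drop the trailing singleton dimension: [[w_h0], [w_h1]] -> [w_h0, w_h1]
            result.append((u, [head[0] for head in attn[eid]]))
    return result


# Toy graph (made-up numbers): edges 0->2, 1->2, 2->0, 2->2, with 2 heads.
src = [0, 1, 2, 2]
dst = [2, 2, 0, 2]
attn = [
    [[0.5], [0.2]],
    [[0.3], [0.5]],
    [[1.0], [1.0]],
    [[0.2], [0.3]],
]

print(node_attention(src, dst, attn, node=2))
# [(0, [0.5, 0.2]), (1, [0.3, 0.5]), (2, [0.2, 0.3])]
```

Because GAT normalizes attention with a softmax over each node's incoming edges, the weights returned for one node should sum to roughly 1 per head, which is a quick sanity check that you are indexing the right edges.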