I am trying to perform semi-supervised node classification using a multi-head GAT. So far it works and produces reasonable labels with both random features and real features. However, I want to better understand and visualise the attention weights learned by the network.
I’ve seen the post on the visualisation of edge attention weights, which is not really what I’m after. I have also read @mufeili’s papers on the statistical characterization of attention in graph neural networks, as well as a related post here. These are great for understanding the importance of a given edge or node within a graph.
Specifically, though, I’m interested in visualising the weights of the attention heads in the different layers as they pertain to (or transform) the input features.
Is there any relevant discourse anyone would recommend on visualising feature-wise attention scores/filters in a graph? And is this broadly possible?
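To make the question concrete, here is a rough sketch of what I have in mind. It assumes the standard GAT formulation, where one head computes the pre-activation attention logit as a_l · (W h_i) + a_r · (W h_j). Under that assumption, each half of the attention vector can be folded back through W to give a vector with one entry per input feature, which could then be plotted as a per-head "filter". Plain Python lists stand in for tensors so that it runs without DGL or PyTorch; every name and number below is hypothetical and not taken from any library.

```python
# Sketch only: assumes e_ij = LeakyReLU(a_l . (W h_i) + a_r . (W h_j)) for one head.
# All names and values are hypothetical; nothing here is DGL's actual API.

def dot(u, v):
    """Dot product of two equal-length lists."""
    return sum(x * y for x, y in zip(u, v))

def project(W, h):
    """Apply the head's linear projection W (out_feats x in_feats) to h."""
    return [dot(row, h) for row in W]

def input_space_filter(W, a):
    """Fold an attention half-vector a (length out_feats) back through W,
    returning W^T a: one weight per input feature."""
    in_feats = len(W[0])
    return [sum(W[o][f] * a[o] for o in range(len(W))) for f in range(in_feats)]

# One head with in_feats=3 and out_feats=2 (toy values).
W = [[1.0, 0.0, 2.0],
     [0.0, -1.0, 1.0]]
a_l = [0.5, 1.0]    # half of the attention vector applied to node i
a_r = [-1.0, 0.25]  # half of the attention vector applied to node j

u_l = input_space_filter(W, a_l)  # per-input-feature weights for node i
u_r = input_space_filter(W, a_r)  # per-input-feature weights for node j

h_i = [1.0, 2.0, 3.0]
h_j = [0.0, 1.0, -1.0]

# The pre-LeakyReLU logit is the same whether computed via the projection
# or via the folded filters, so the filters fully describe how each input
# feature contributes to the score.
logit_via_projection = dot(a_l, project(W, h_i)) + dot(a_r, project(W, h_j))
logit_via_filters = dot(u_l, h_i) + dot(u_r, h_j)

print(u_l, u_r, logit_via_projection, logit_via_filters)
```

Note that this only maps directly onto the raw input features for the first layer; for deeper layers the same folding would give weights over that layer's hidden features instead.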
Also, what are the three keys in the state_dict of a trained GATConv layer, accessed as in the example below?