How to plot the attention weights?

I want to draw a graph with attention weights like this:

https://docs.dgl.ai/tutorials/models/1_gnn/9_gat.html#visualizing-and-understanding-attention-learnt

Is there any example code for visualizing the GAT?
Thanks a lot!


Below is a variant of the code I used for plotting attentions over graphs. You may treat it as a starting point for your own use.

import matplotlib.pyplot as plt
import networkx as nx
import seaborn as sns  # used for the default node color palette

def plot(g, attention, ax, nodes_to_plot=None, nodes_labels=None,
         edges_to_plot=None, nodes_pos=None, nodes_colors=None,
         edge_colormap=plt.cm.Reds):
    """
    Visualize edge attentions by coloring edges on the graph.
    g: nx.DiGraph
        Directed networkx graph
    attention: list
        Attention values in [0, 1], corresponding to the order of
        sorted(g.edges()), or of edges_to_plot if it is given
    ax: matplotlib.axes.Axes
        ax to be used for plot
    nodes_to_plot: list
        List of node ids specifying which nodes to plot. Defaults to
        None. If None, all nodes will be plotted.
    nodes_labels: list, numpy.array
        nodes_labels[i] specifies the label of the ith node, which will
        decide the node color on the plot. Defaults to None. If None,
        all nodes will have the same canonical label. nodes_labels
        should contain labels for all nodes to be plotted.
    edges_to_plot: list of 2-tuples (i, j)
        List of edges represented as (source, destination). Defaults to
        None. If None, all edges will be plotted.
    nodes_pos: dictionary mapping int to numpy.array of size 2
        Specifies the layout of nodes on the plot. Defaults to None.
        If None, nx.spring_layout(g) is used.
    nodes_colors: list
        Specifies the node color for each node class. Its length should
        be at least the number of node classes in nodes_labels.
    edge_colormap: plt.cm
        Specifies the colormap to be used for coloring edges.
    """
    if nodes_to_plot is None:
        nodes_to_plot = sorted(g.nodes())
    if edges_to_plot is None:
        assert isinstance(g, nx.DiGraph), 'Expected g to be a networkx.DiGraph ' \
                                          'object, got {}.'.format(type(g))
        edges_to_plot = sorted(g.edges())
    if nodes_pos is None:
        # Drawing fails with nodes_pos=None, so fall back to a default layout
        nodes_pos = nx.spring_layout(g)
    nx.draw_networkx_edges(g, nodes_pos, edgelist=edges_to_plot,
                           edge_color=attention, edge_cmap=edge_colormap,
                           width=2, alpha=0.5, ax=ax, edge_vmin=0,
                           edge_vmax=1)

    if nodes_labels is None:
        # Every node gets the same canonical label 0
        nodes_labels = {v: 0 for v in g.nodes()}
    if nodes_colors is None:
        num_classes = max(nodes_labels[v] for v in nodes_to_plot) + 1
        nodes_colors = sns.color_palette("deep", num_classes)

    # Node ids are assumed to be 0-based (as produced by DGLGraph.to_networkx()),
    # so nodes_labels[v] is the label of node v.
    # with_labels is not a draw_networkx_nodes argument, so it is not passed.
    nx.draw_networkx_nodes(g, nodes_pos, nodelist=nodes_to_plot, ax=ax, node_size=30,
                           node_color=[nodes_colors[nodes_labels[v]] for v in nodes_to_plot],
                           alpha=0.9)

fig, ax = plt.subplots()
plot(..., ax=ax, ...)
ax.set_axis_off()
sm = plt.cm.ScalarMappable(cmap=plt.cm.Reds, norm=plt.Normalize(vmin=0, vmax=1))
sm.set_array([])
plt.colorbar(sm, ax=ax, fraction=0.046, pad=0.01)  # ax= tells matplotlib where to place the colorbar
plt.show()
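For a run that works end to end without the elided arguments, here is a minimal sketch that follows the same steps as plot() but calls networkx directly on a toy graph. The graph and the attention values are made up for illustration; it assumes matplotlib, networkx and NumPy are installed and uses the non-interactive Agg backend so it also runs without a display.

```python
import io

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

# Toy directed graph standing in for g.to_networkx(); edges are (src, dst)
g = nx.DiGraph([(0, 1), (1, 2), (2, 0), (2, 3), (3, 1)])
edges = sorted(g.edges())

# Hypothetical attention values in [0, 1], one per edge in sorted(g.edges()) order
rng = np.random.default_rng(0)
attention = rng.uniform(0, 1, size=len(edges))

# nodes_pos must be an actual layout, not None
pos = nx.spring_layout(g, seed=0)

fig, ax = plt.subplots()
nx.draw_networkx_edges(g, pos, edgelist=edges, edge_color=attention,
                       edge_cmap=plt.cm.Reds, edge_vmin=0, edge_vmax=1,
                       width=2, ax=ax)
nx.draw_networkx_nodes(g, pos, nodelist=sorted(g.nodes()), node_size=30, ax=ax)
ax.set_axis_off()

# Colorbar for the edge colormap, attached to ax
sm = plt.cm.ScalarMappable(cmap=plt.cm.Reds, norm=plt.Normalize(vmin=0, vmax=1))
sm.set_array([])
fig.colorbar(sm, ax=ax, fraction=0.046, pad=0.01)

# Render to an in-memory PNG instead of opening a window
buf = io.BytesIO()
fig.savefig(buf, format="png")
```

In an interactive session, replace the last two lines with plt.show().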

OK, thanks for your reply. I will try your code.

It really helps. Thx!

Hi, thanks for your plotting code, but could you share a working example?

I keep getting TypeError: 'NoneType' object is not subscriptable. I first convert the DGL graph to networkx with ntx = model.g.to_networkx(), and both nodes_to_plot and edges_to_plot give me sorted values. Any idea why this is happening?

You can't pass nodes_pos=None; that is what causes the error. Compute a layout with one of the networkx layout functions and pass it as nodes_pos:

https://networkx.github.io/documentation/stable/reference/drawing.html#layout
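For example, a layout can be computed once and passed as nodes_pos. The small cycle below is only a stand-in for the graph returned by model.g.to_networkx().

```python
import networkx as nx

# Stand-in for ntx = model.g.to_networkx()
ntx = nx.DiGraph([(0, 1), (1, 2), (2, 0)])

# Each layout returns {node_id: numpy array [x, y]}, which is what nodes_pos expects
nodes_pos = nx.spring_layout(ntx, seed=42)  # force-directed; seed makes it reproducible
circle_pos = nx.circular_layout(ntx)        # deterministic alternative
```

Fixing the seed keeps the node positions identical across runs, which makes plots of different models or epochs easier to compare.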

However, now I get:

RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.

This is exactly what the error message says: some input x is a torch tensor that requires gradients, so you need to call x.detach().numpy() on it first.
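A small helper can do this conversion before calling plot(). The name to_numpy is hypothetical; it is a sketch that treats anything with a .detach() method as a torch tensor, so the helper itself does not need to import torch.

```python
import numpy as np

def to_numpy(x):
    """Hypothetical helper: convert attention values to a flat NumPy array.

    Objects with a .detach() method (e.g. torch tensors that require grad)
    are detached from the autograd graph and moved to CPU first; lists and
    NumPy arrays are passed through np.asarray unchanged.
    """
    if hasattr(x, "detach"):
        x = x.detach().cpu().numpy()
    return np.asarray(x, dtype=float).ravel()
```

If the attention has several heads, reduce them to one value per edge (for example by averaging over the head dimension) before flattening, otherwise values from different heads end up interleaved.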


Hi @mufeili. Thanks for posting this helpful guide to plotting attention weights. However, I'm struggling to get the attention argument passed to plot() to work. I use DGL molecular graphs to featurize SMILES strings (with the smiles_to_bigraph method) before following your example. Here is my code:

g_dgl = broad_graph[3]
print(type(g_dgl))
dg = g_dgl.to_networkx().to_directed()
print(type(dg))
dg.edges()
#Y.reshape(Y.shape[0])
grEd = dg.edges()
print(grEd)
fig, ax = plt.subplots()
plot(dg, sorted(grEd), ax=ax, nodes_pos=nx.random_layout(dg), nodes_labels=[0, 1, 2])
ax.set_axis_off()
sm = plt.cm.ScalarMappable(cmap=plt.cm.Reds, norm=plt.Normalize(vmin=0, vmax=1))
sm.set_array([])
plt.colorbar(sm, fraction=0.046, pad=0.01)
plt.show()

I get the following error:

ValueError: RGBA sequence should have length 3 or 4

Any help is greatly appreciated. Thanks!

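You did not pass a valid attention argument to plot(). It should be a list of numbers in the range [0, 1], where attention[i] is the weight on edge sorted(dg.edges())[i]. You passed sorted(grEd), a list of (source, destination) tuples, so matplotlib tried to interpret each tuple as an RGBA color, which is why you got the ValueError.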
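A minimal sketch of building a correctly ordered attention list, assuming the per-edge weights are already available in a dict keyed by (source, destination). The graph and the weights below are made up for illustration.

```python
import networkx as nx

# Stand-in for the molecule graph dg (bonds appear in both directions)
dg = nx.DiGraph([(0, 1), (1, 0), (1, 2), (2, 1)])

# Hypothetical per-edge weights in [0, 1], keyed by (src, dst), in arbitrary order
weights = {(1, 2): 0.9, (0, 1): 0.2, (2, 1): 0.4, (1, 0): 0.6}

# attention[i] must be the weight of sorted(dg.edges())[i]
edges = sorted(dg.edges())
attention = [weights[e] for e in edges]
```

The resulting list can then be passed as plot(dg, attention, ax=ax, ...).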

Also, if you are only interested in molecules, you may want to visualize edge weights with RDKit. Check the Interpretation section.

Hello, thanks for your code! I have a simple question. I have trained a GAT network for molecule classification, but I don't know how to get the attention weights of the last GATLayer.

Do you mean how to pick among multiple attention heads?

I am training a GAT model for graph classification and want to get the attention scores between each node and its neighbors. How can I do that? Thanks for your reply.

I saw this URL, but how can I do that?

Below is an example that picks a subgraph of Cora and plots the attention weights of the last GATLayer.

I just randomly sampled a subgraph for the demo.
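As a sketch of what those attention weights are, here is a from-scratch NumPy version of how one GAT attention head scores each edge and then normalizes the scores over each node's incoming edges (edge softmax). It is an illustration of the technique, not DGL's implementation; in recent DGL versions, dgl.nn.pytorch.GATConv can return the learned weights directly when its forward is called with get_attention=True, one value per edge and head.

```python
import numpy as np

def gat_attention(h, W, a, src, dst, negative_slope=0.2):
    """Single-head GAT attention coefficients, one per edge src[k] -> dst[k].

    h: (N, F_in) node features; W: (F_in, F_out) projection;
    a: (2 * F_out,) attention vector.
    Returns alpha where alpha[k] is normalized over all edges into dst[k].
    """
    z = h @ W                                   # projected features W h
    f_out = z.shape[1]
    # Unnormalized score e_k = LeakyReLU(a^T [W h_dst || W h_src])
    e = z[dst] @ a[:f_out] + z[src] @ a[f_out:]
    e = np.where(e > 0, e, negative_slope * e)
    # Edge softmax: normalize scores over the incoming edges of each node
    alpha = np.empty_like(e)
    for v in np.unique(dst):
        mask = dst == v
        s = np.exp(e[mask] - e[mask].max())     # subtract max for numerical stability
        alpha[mask] = s / s.sum()
    return alpha
```

Because each node's incoming weights sum to 1, the values already lie in [0, 1] and can be passed to plot() as edge colors once they are ordered like sorted(g.edges()).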


If you are doing graph classification and your graphs are not very large, you may simply plot whole graphs with edges colored based on attention weights.
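When training on batches of graphs, the attention from one forward pass covers all edges of the batch. A minimal NumPy sketch for splitting it back into one array per graph (averaging over heads) and for reading off a node's neighbors; both helper names are hypothetical. It assumes the edges of the batch are ordered graph by graph, as in DGL's batched graphs, where recent versions report per-graph edge counts via g.batch_num_edges() and edge endpoints via g.edges().

```python
import numpy as np

def attention_per_graph(attn, num_edges_per_graph):
    """Hypothetical helper: split batched attention into one array per graph.

    attn: (E_total, H) or (E_total, H, 1) attention for all edges of a batch,
          assumed to be ordered graph by graph.
    num_edges_per_graph: number of edges in each graph of the batch.
    Heads are averaged, so each edge gets a single value.
    """
    attn = np.asarray(attn, dtype=float)
    attn = attn.reshape(len(attn), -1).mean(axis=1)
    boundaries = np.cumsum(num_edges_per_graph)[:-1]
    return np.split(attn, boundaries)

def incoming_attention(src, dst, attn, v):
    """Hypothetical helper: {neighbor u: weight of edge u -> v} within one graph."""
    mask = np.asarray(dst) == v
    return dict(zip(np.asarray(src)[mask].tolist(), np.asarray(attn)[mask].tolist()))
```

Averaging is only one way to combine heads; plotting each head separately may be more informative when the heads attend to different neighbors.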