I want to draw a graph with attention weights like this:

https://docs.dgl.ai/tutorials/models/1_gnn/9_gat.html#visualizing-and-understanding-attention-learnt

Is there any example code for visualizing the GAT?

Thanks a lot!

Below is a variant of the code I used to plot attention weights over graphs. Treat it as a starting point for your own use.

```python
import matplotlib.pyplot as plt
import networkx as nx
import seaborn as sns


def plot(g, attention, ax, nodes_to_plot=None, nodes_labels=None,
         edges_to_plot=None, nodes_pos=None, nodes_colors=None,
         edge_colormap=plt.cm.Reds):
    """
    Visualize edge attentions by coloring edges on the graph.

    g: nx.DiGraph
        Directed networkx graph whose nodes are the integers 0, ..., N - 1
    attention: list
        Attention values in [0, 1], in the order of sorted(g.edges())
        (or of edges_to_plot, if given)
    ax: matplotlib.axes.Axes
        Axes to draw the plot on
    nodes_to_plot: list
        List of node ids specifying which nodes to plot. Defaults to
        None, in which case all nodes are plotted.
    nodes_labels: list, numpy.array
        nodes_labels[i] specifies the integer label of the ith node, which
        decides the node color on the plot. Defaults to None, in which
        case all nodes get the same canonical label. nodes_labels should
        contain labels for all nodes to be plotted.
    edges_to_plot: list of 2-tuples (i, j)
        List of edges represented as (source, destination). Defaults to
        None, in which case all edges are plotted.
    nodes_pos: dictionary mapping int to numpy.array of size 2
        Specifies the layout of nodes on the plot. Defaults to None, in
        which case a spring layout is computed.
    nodes_colors: list
        Specifies the node color for each node class. Its length should be
        greater than the largest label in nodes_labels.
    edge_colormap: plt.cm
        Specifies the colormap to be used for coloring edges.
    """
    if nodes_to_plot is None:
        nodes_to_plot = sorted(g.nodes())
    if edges_to_plot is None:
        assert isinstance(g, nx.DiGraph), \
            'Expected g to be a networkx.DiGraph object, got {}.'.format(type(g))
        edges_to_plot = sorted(g.edges())
    if nodes_pos is None:
        # Fall back to a spring layout when no positions are given
        nodes_pos = nx.spring_layout(g)
    if nodes_labels is None:
        # Give every node the same label (0), hence the same color
        nodes_labels = [0] * g.number_of_nodes()
    # Edge color encodes the attention value, mapped through the colormap on [0, 1]
    nx.draw_networkx_edges(g, nodes_pos, edgelist=edges_to_plot,
                           edge_color=attention, edge_cmap=edge_colormap,
                           width=2, alpha=0.5, ax=ax, edge_vmin=0,
                           edge_vmax=1)
    if nodes_colors is None:
        nodes_colors = sns.color_palette("deep", max(nodes_labels) + 1)
    # Node ids are assumed to be 0-indexed, so node v uses nodes_labels[v]
    nx.draw_networkx_nodes(g, nodes_pos, nodelist=nodes_to_plot, ax=ax, node_size=30,
                           node_color=[nodes_colors[nodes_labels[v]] for v in nodes_to_plot],
                           alpha=0.9)


fig, ax = plt.subplots()
plot(..., ax=ax, ...)  # fill in your own graph, attention values, etc.
ax.set_axis_off()
# Colorbar showing how attention values map to edge colors
sm = plt.cm.ScalarMappable(cmap=plt.cm.Reds, norm=plt.Normalize(vmin=0, vmax=1))
sm.set_array([])
plt.colorbar(sm, ax=ax, fraction=0.046, pad=0.01)
plt.show()
```
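For a fully runnable version of the same recipe without a trained model, here is a self-contained sketch. It assumes the Zachary karate club graph from networkx as a stand-in dataset and uses random numbers as raw edge scores (hypothetical placeholders, not real GAT output). The scores are normalized with a softmax over each node's incoming edges, which is how GAT normalizes its attention coefficients. The helper names `edge_softmax_by_dst`, `make_example` and `draw` are made up for this example, and it uses matplotlib's `tab10` colors instead of seaborn. To visualize a real model, replace `attention` with the per-edge attention values from your trained GAT, ordered like `sorted(g.edges())`.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch also runs without a display
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np


def edge_softmax_by_dst(edges, scores):
    """Softmax raw edge scores over the incoming edges of each destination node."""
    scores = np.asarray(scores, dtype=float)
    attention = np.empty_like(scores)
    dsts = np.array([dst for _, dst in edges])
    for v in np.unique(dsts):
        mask = dsts == v
        # Subtract the max before exponentiating for numerical stability
        exp = np.exp(scores[mask] - scores[mask].max())
        attention[mask] = exp / exp.sum()
    return attention


def make_example(seed=0):
    rng = np.random.default_rng(seed)
    # Karate club graph with nodes 0..33, every edge present in both directions
    g = nx.karate_club_graph().to_directed()
    edges = sorted(g.edges())
    # HYPOTHETICAL scores: random numbers standing in for a trained GAT's logits
    scores = rng.normal(size=len(edges))
    attention = edge_softmax_by_dst(edges, scores)
    # Integer node labels (0 or 1) taken from the "club" node attribute
    clubs = sorted({d["club"] for _, d in g.nodes(data=True)})
    labels = [clubs.index(g.nodes[v]["club"]) for v in sorted(g.nodes())]
    return g, edges, attention, labels


def draw(g, edges, attention, labels, ax):
    pos = nx.spring_layout(g, seed=0)
    # Edge color encodes attention on a fixed [0, 1] scale
    nx.draw_networkx_edges(g, pos, edgelist=edges, edge_color=attention,
                           edge_cmap=plt.cm.Reds, edge_vmin=0, edge_vmax=1,
                           width=2, alpha=0.5, ax=ax)
    colors = plt.cm.tab10.colors
    nodes = sorted(g.nodes())
    nx.draw_networkx_nodes(g, pos, nodelist=nodes, node_size=30,
                           node_color=[colors[labels[v]] for v in nodes],
                           alpha=0.9, ax=ax)
    ax.set_axis_off()
    sm = plt.cm.ScalarMappable(cmap=plt.cm.Reds, norm=plt.Normalize(vmin=0, vmax=1))
    sm.set_array([])
    plt.colorbar(sm, ax=ax, fraction=0.046, pad=0.01)


if __name__ == "__main__":
    g, edges, attention, labels = make_example()
    fig, ax = plt.subplots()
    draw(g, edges, attention, labels, ax)
    fig.savefig("gat_attention.png", dpi=150)
```

The sketch writes the figure to a PNG file; to view it interactively instead, drop the `matplotlib.use("Agg")` line and call `plt.show()` in place of `fig.savefig(...)`. Fixing `edge_vmin=0` and `edge_vmax=1` keeps the colors comparable across plots, since softmax-normalized attention always lies in [0, 1].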


OK, thanks for your reply. I will try your code.

It really helps. Thx!

Hi, thanks for your plotting code, but could you release a working example?