How to plot the attention weights...?

I want to draw a graph with attention weights like this:

Is there any example code for visualizing the GAT?
Thanks a lot!


Below is a variant of the code I used for plotting attentions over graphs. You may treat it as a starting point for your own use.

import matplotlib.pyplot as plt
import networkx as nx
import seaborn as sns

def plot(g, attention, ax, nodes_to_plot=None, nodes_labels=None,
         edges_to_plot=None, nodes_pos=None, nodes_colors=None,
         edge_colormap=plt.cm.Reds):
    """
    Visualize edge attentions by coloring edges on the graph.

    g: nx.DiGraph
        Directed networkx graph
    attention: list
        Attention values corresponding to the order of sorted(g.edges())
    ax: matplotlib.axes.Axes
        ax to be used for plot
    nodes_to_plot: list
        List of node ids specifying which nodes to plot. Default to
        be None. If None, all nodes will be plotted.
    nodes_labels: list, numpy.array
        nodes_labels[i] specifies the label of the ith node, which will
        decide the node color on the plot. Default to be None. If None,
        all nodes will have the same canonical label. The nodes_labels
        should contain labels for all nodes to be plotted.
    edges_to_plot: list of 2-tuples (i, j)
        List of edges represented as (source, destination). Default to
        be None. If None, all edges will be plotted.
    nodes_pos: dictionary mapping int to numpy.array of size 2
        Specifies the layout of nodes on the plot. Default to be None,
        but the networkx drawing calls below need actual positions.
    nodes_colors: list
        Specifies node color for each node class. Its length should be
        at least the number of node classes in nodes_labels.
    edge_colormap: plt.cm
        Specifies the colormap to be used for coloring edges.
    """
    if nodes_to_plot is None:
        nodes_to_plot = sorted(g.nodes())
    if edges_to_plot is None:
        assert isinstance(g, nx.DiGraph), 'Expected g to be a networkx.DiGraph ' \
                                          'object, got {}.'.format(type(g))
        edges_to_plot = sorted(g.edges())
    if nodes_labels is None:
        # Give every node the same label so that all nodes share one color.
        nodes_labels = [0] * g.number_of_nodes()

    # Edge colors are read from the colormap on a fixed [0, 1] attention scale.
    nx.draw_networkx_edges(g, nodes_pos, edgelist=edges_to_plot,
                           edge_color=attention, edge_cmap=edge_colormap,
                           width=2, alpha=0.5, ax=ax, edge_vmin=0,
                           edge_vmax=1)

    if nodes_colors is None:
        nodes_colors = sns.color_palette("deep", max(nodes_labels) + 1)

    # v - 1 assumes node ids start at 1; use nodes_labels[v] for 0-based ids.
    nx.draw_networkx_nodes(g, nodes_pos, nodelist=nodes_to_plot, ax=ax, node_size=30,
                           node_color=[nodes_colors[nodes_labels[v - 1]] for v in nodes_to_plot],
                           alpha=0.9)

fig, ax = plt.subplots()
plot(..., ax=ax, ...)
sm = plt.cm.ScalarMappable(cmap=plt.cm.Reds, norm=plt.Normalize(vmin=0, vmax=1))
plt.colorbar(sm, ax=ax, fraction=0.046, pad=0.01)
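
For anyone who wants something that runs end to end, here is a minimal self-contained sketch of the same idea. It is not the code from any model: the toy graph, the random scores, and the per-destination softmax are all made up for illustration, and the layout comes from nx.spring_layout, which is only one possible choice.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line for interactive use
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

# Toy directed graph standing in for the graph converted from a model (assumption).
g = nx.DiGraph()
g.add_edges_from([(0, 1), (0, 2), (1, 2), (2, 0), (3, 2), (3, 0)])
edges = sorted(g.edges())

# Made-up attention: random scores, softmax-normalized over each
# destination node's incoming edges, listed in the order of `edges`.
rng = np.random.default_rng(0)
scores = {e: rng.normal() for e in edges}
attention = []
for (u, v) in edges:
    incoming = np.array([scores[(s, v)] for s in g.predecessors(v)])
    attention.append(float(np.exp(scores[(u, v)]) / np.exp(incoming).sum()))

# The drawing functions need explicit node positions.
pos = nx.spring_layout(g, seed=0)

fig, ax = plt.subplots()
nx.draw_networkx_edges(g, pos, edgelist=edges, edge_color=attention,
                       edge_cmap=plt.cm.Reds, edge_vmin=0, edge_vmax=1,
                       width=2, ax=ax)
nx.draw_networkx_nodes(g, pos, node_size=200, ax=ax)
ax.set_axis_off()

# Colorbar on the same [0, 1] scale that was used for the edges.
sm = plt.cm.ScalarMappable(cmap=plt.cm.Reds, norm=plt.Normalize(vmin=0, vmax=1))
sm.set_array([])
fig.colorbar(sm, ax=ax, fraction=0.046, pad=0.01)
# fig.savefig("attention.png")  # or plt.show() with an interactive backend
```

The attention list has to follow the same order as the edgelist passed to draw_networkx_edges, which is why both are built from sorted(g.edges()).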

Ok, thanks for your reply and I will try your code. :grinning:

It really helps. Thx!

Hi, thanks for your plot, but can you release a working example?

I keep getting TypeError: 'NoneType' object is not subscriptable. I first convert the DGL graph to networkx using ntx = model.g.to_networkx(), and both nodes_to_plot and edges_to_plot give me sorted values. Any idea why this is happening?

You can’t leave nodes_pos=None; that is what causes the error.

However, now I get:

RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.

It is just what the message says: some input x is a torch tensor that requires gradient, so you need to call x.detach().numpy() on it first.
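
A small sketch of that conversion, in case it helps. The tensor here is a hypothetical stand-in for whatever attention values your model produces (one score per edge is assumed); only detach(), cpu(), and numpy() are the relevant calls.

```python
import torch

# Hypothetical attention scores, one per edge, still attached to the
# autograd graph (as they would be when taken from inside a model).
logits = torch.randn(6, requires_grad=True)
attention = torch.softmax(logits, dim=0)

# attention.numpy() would raise the RuntimeError quoted above, so detach
# from the graph first (and move to CPU in case the tensor is on a GPU).
attention_np = attention.detach().cpu().numpy()

# Plain Python floats, usable as edge colors for plotting.
attention_list = attention_np.tolist()
```

If the values are only needed for plotting, computing them inside a torch.no_grad() block is another option that avoids the error.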