I’ve modified your code a bit; see the code below for an example.
import matplotlib.pyplot as plt
import networkx as nx
import seaborn as sns
import torch
from dgl import DGLGraph
def plot(g, attention, ax, nodes_to_plot=None, nodes_labels=None,
         edges_to_plot=None, nodes_pos=None, nodes_colors=None,
         edge_colormap=plt.cm.Reds):
    """Draw a graph on ``ax`` with edges colored by attention values.

    Args:
        g: Graph to draw. Must be a ``networkx.DiGraph`` (or subclass) when
            ``edges_to_plot`` is not supplied.
        attention: Per-edge values passed to networkx as ``edge_color`` and
            mapped through ``edge_colormap`` on a fixed 0..1 scale.
            Presumably one value per entry of ``edges_to_plot``, in the same
            order -- confirm against the caller.
        ax: Matplotlib axes to draw on.
        nodes_to_plot: Nodes to draw. Defaults to ``sorted(g.nodes())``.
        nodes_labels: Sequence of integer labels used to pick each node's
            color. Required despite the ``None`` default.
        edges_to_plot: Edges to draw. Defaults to ``sorted(g.edges())``.
        nodes_pos: Passed as the ``pos`` argument to both networkx drawing
            calls.
        nodes_colors: Palette indexed by label. Defaults to seaborn's "deep"
            palette with ``max(nodes_labels) + 1`` colors.
        edge_colormap: Matplotlib colormap applied to ``attention``.

    Raises:
        TypeError: If ``nodes_labels`` is missing, or if ``edges_to_plot`` is
            not given and ``g`` is not a ``networkx.DiGraph``.
    """
    # Fail before anything is drawn; previously this surfaced later as an
    # opaque "'NoneType' object is not iterable/subscriptable" TypeError.
    if nodes_labels is None:
        raise TypeError('nodes_labels is required to choose node colors.')
    if nodes_to_plot is None:
        nodes_to_plot = sorted(g.nodes())
    if edges_to_plot is None:
        # Explicit raise instead of assert so the check survives `python -O`.
        if not isinstance(g, nx.DiGraph):
            raise TypeError(
                'Expected g to be a networkx.DiGraph object, got {}.'.format(
                    type(g)))
        edges_to_plot = sorted(g.edges())
    nx.draw_networkx_edges(g, nodes_pos, edgelist=edges_to_plot,
                           edge_color=attention, edge_cmap=edge_colormap,
                           width=2, alpha=0.5, ax=ax, edge_vmin=0,
                           edge_vmax=1)
    if nodes_colors is None:
        nodes_colors = sns.color_palette("deep", max(nodes_labels) + 1)
    # NOTE(review): the `v - 1` lookup looks like it expects 1-based node ids;
    # for node 0 it wraps around to the last label. Kept unchanged because
    # existing callers may rely on it -- confirm the intended indexing.
    node_colors = [nodes_colors[nodes_labels[v - 1]] for v in nodes_to_plot]
    # `with_labels` is not a parameter of draw_networkx_nodes; passing it
    # raises TypeError on current networkx releases, so it is omitted.
    nx.draw_networkx_nodes(g, nodes_pos, nodelist=nodes_to_plot, ax=ax,
                           node_size=30, node_color=node_colors, alpha=0.9)
# Build a small demo graph (nodes 0..3, three edges) in DGL and convert it to
# a directed networkx graph for plotting.
g_dgl = DGLGraph([(0, 1), (1, 2), (1, 3)])
print(type(g_dgl))
dg = g_dgl.to_networkx().to_directed()
print(type(dg))
grEd = dg.edges()
print(grEd)

fig, ax = plt.subplots()
# One attention value per edge.
# NOTE(review): only three labels are given for four nodes; this works only
# because plot() indexes labels with `v - 1`, so node 0 wraps to the last
# label. Confirm that this is the intended labeling.
plot(dg, [0.2, 0.6, 0.6], ax=ax, nodes_pos=nx.random_layout(dg),
     nodes_labels=[0, 1, 2])
ax.set_axis_off()

# Standalone mappable matching the 0..1 edge color scale used in plot().
sm = plt.cm.ScalarMappable(cmap=plt.cm.Reds, norm=plt.Normalize(vmin=0, vmax=1))
sm.set_array([])
# The mappable is not attached to any Axes, so tell colorbar explicitly which
# Axes to take space from; newer Matplotlib cannot determine it otherwise.
plt.colorbar(sm, ax=ax, fraction=0.046, pad=0.01)
plt.show()