I’ve modified your code a bit; see the code below for an example.

```
import matplotlib.pyplot as plt
import networkx as nx
import seaborn as sns
import torch
from dgl import DGLGraph
def plot(g, attention, ax, nodes_to_plot=None, nodes_labels=None,
         edges_to_plot=None, nodes_pos=None, nodes_colors=None,
         edge_colormap=plt.cm.Reds, node_id_offset=1):
    """Draw ``g`` on ``ax`` with edges colored by attention and nodes by label.

    Args:
        g: networkx graph to draw. Must be a ``networkx.DiGraph`` (or a
            subclass) when ``edges_to_plot`` is not given.
        attention: One value per drawn edge, in the same order as
            ``edges_to_plot`` (by default ``sorted(g.edges())``). Values are
            mapped through ``edge_colormap`` on a fixed [0, 1] range.
        ax: Matplotlib Axes to draw on.
        nodes_to_plot: Nodes to draw; defaults to ``sorted(g.nodes())``.
        nodes_labels: Integer label per node. Node ``v`` is colored with
            ``nodes_colors[nodes_labels[v - node_id_offset]]``. If None, every
            node is drawn in ``nodes_colors[0]``.
        edges_to_plot: Edges to draw; defaults to ``sorted(g.edges())``.
        nodes_pos: Mapping from node to (x, y) position. If None, a
            ``networkx.spring_layout`` of ``g`` is computed once and used for
            both nodes and edges.
        nodes_colors: Colors indexed by label. Defaults to a seaborn "deep"
            palette with ``max(nodes_labels) + 1`` colors (one color when
            ``nodes_labels`` is None).
        edge_colormap: Matplotlib colormap for the attention values.
        node_id_offset: Subtracted from a node id to get its index into
            ``nodes_labels``. The default of 1 keeps the original lookup
            (``nodes_labels[v - 1]``, i.e. 1-based node ids); pass 0 for
            0-based graphs such as those converted from DGL.

    Raises:
        AssertionError: if ``edges_to_plot`` is None and ``g`` is not a
            ``networkx.DiGraph``.
    """
    if nodes_to_plot is None:
        nodes_to_plot = sorted(g.nodes())
    if edges_to_plot is None:
        assert isinstance(g, nx.DiGraph), (
            'Expected g to be a networkx.DiGraph object, got {}.'.format(type(g)))
        edges_to_plot = sorted(g.edges())
    if nodes_pos is None:
        # networkx cannot draw without positions; compute one layout and reuse
        # it for both edges and nodes so they line up.
        nodes_pos = nx.spring_layout(g)

    nx.draw_networkx_edges(g, nodes_pos, edgelist=edges_to_plot,
                           edge_color=attention, edge_cmap=edge_colormap,
                           width=2, alpha=0.5, ax=ax, edge_vmin=0,
                           edge_vmax=1)

    if nodes_colors is None:
        n_colors = 1 if nodes_labels is None else max(nodes_labels) + 1
        nodes_colors = sns.color_palette("deep", n_colors)
    if nodes_labels is None:
        # One entry per node rather than a bare RGB tuple, so matplotlib cannot
        # mistake the tuple for three scalar values when exactly 3 nodes are drawn.
        node_color = [nodes_colors[0]] * len(nodes_to_plot)
    else:
        # NOTE: with the default offset of 1 and 0-based node ids, node 0 maps
        # to index -1 and silently takes the *last* label.
        node_color = [nodes_colors[nodes_labels[v - node_id_offset]]
                      for v in nodes_to_plot]

    # draw_networkx_nodes never draws labels. The former `with_labels=False`
    # was ignored by networkx 2.x and is rejected (TypeError) by recent releases.
    nx.draw_networkx_nodes(g, nodes_pos, nodelist=nodes_to_plot, ax=ax,
                           node_size=30, node_color=node_color, alpha=0.9)
# Build a small example graph in DGL and convert it to a networkx DiGraph,
# which `plot` expects when edges_to_plot is not given.
g_dgl = DGLGraph([(0, 1), (1, 2), (1, 3)])
print(type(g_dgl))
dg = g_dgl.to_networkx().to_directed()
print(type(dg))
edge_list = dg.edges()
print(edge_list)

fig, ax = plt.subplots()
# One attention value per edge, in sorted(dg.edges()) order:
# (0, 1) -> 0.2, (1, 2) -> 0.6, (1, 3) -> 0.6.
# NOTE(review): the graph has 4 nodes but only 3 labels are passed; `plot`
# looks labels up as nodes_labels[v - 1], so node 0 wraps to the last label.
plot(dg, [0.2, 0.6, 0.6], ax=ax, nodes_pos=nx.random_layout(dg),
     nodes_labels=[0, 1, 2])
ax.set_axis_off()

# Colorbar for the edge colormap on the same fixed [0, 1] range used for edges.
sm = plt.cm.ScalarMappable(cmap=plt.cm.Reds, norm=plt.Normalize(vmin=0, vmax=1))
sm.set_array([])  # kept for older Matplotlib releases that require an array
# `sm` is not attached to any Axes, so tell colorbar which Axes to take space
# from; recent Matplotlib raises ValueError when it cannot infer this.
plt.colorbar(sm, ax=ax, fraction=0.046, pad=0.01)
plt.show()
```