How do you log the attention weights for minibatched graphs?

DGL uses message flow graphs (MFGs) for minibatch training. What I am not sure about is how to map the attention weights returned by a GATConv layer on these minibatch MFGs back to the edges of the original graph.
I am currently using an inference function like this one for GAT.
Thank you.

I noticed this sentence in the block documentation: “This indicates that edge (2, 3) and (1, 2) are included in the result graph.”
How are those two edges identified? Is there an ordering that relates edge IDs to the induced nodes in an MFG? I believe we can’t recover the global edge ordering unless we do this:

src, dst = block.edges(order='eid')
induced_src[src], induced_dst[dst]
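To make the index composition concrete, here is a small NumPy sketch (torch tensors index the same way). The arrays are hypothetical stand-ins for what block.edges(order='eid'), block.srcdata[dgl.NID] and block.dstdata[dgl.NID] would return for an MFG whose destination nodes are 3 and 2, with the destination nodes listed first among the source nodes; the helper name to_global_endpoints is made up for illustration. It recovers the two edges (2, 3) and (1, 2) quoted above.

```python
import numpy as np

# Hypothetical toy data standing in for one MFG (block). In DGL these would be
# block.edges(order='eid'), block.srcdata[dgl.NID] and block.dstdata[dgl.NID].
local_src = np.array([1, 2])        # block-local source IDs, in block edge-ID order
local_dst = np.array([0, 1])        # block-local destination IDs
induced_src = np.array([3, 2, 1])   # block-local src ID -> original node ID
induced_dst = np.array([3, 2])      # block-local dst ID -> original node ID

def to_global_endpoints(local_src, local_dst, induced_src, induced_dst):
    """Translate block-local edge endpoints into original-graph node IDs."""
    return induced_src[local_src], induced_dst[local_dst]

global_src, global_dst = to_global_endpoints(local_src, local_dst, induced_src, induced_dst)
print(list(zip(global_src.tolist(), global_dst.tolist())))  # [(2, 3), (1, 2)]
```

The pairs come out in block edge-ID order, so position i of the result corresponds to row i of the attention tensor returned for that block.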

This is what I have so far; I’m not sure whether I am making a mistake somewhere:

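import torch
import dgl

def mapping_edges(edges):
    # Map each (src, dst) pair of the original graph to its edge ID.
    # g.edges() returns edges in edge-ID order, so the position is the edge ID.
    # Note: on a multigraph, parallel edges collapse to the last edge ID.
    src, dst = edges[0].cpu().numpy(), edges[1].cpu().numpy()
    return {(j, k): i for i, (j, k) in enumerate(zip(src, dst))}

edge_mapping = mapping_edges(g.edges())

# 8 layer-1 heads + 1 layer-2 head
attentions = torch.zeros(g.num_edges(), 9, device=device)

# Inside the dataloader loop, for each MFG `block` (layer-wise inference):
src, dst = block.edges(order='eid')    # block-local node IDs, in block edge-ID order
induced_src = block.srcdata[dgl.NID]   # block-local src ID -> original node ID
induced_dst = block.dstdata[dgl.NID]   # block-local dst ID -> original node ID

# Original edge ID of every block edge
att_edge_order = [edge_mapping[(i, j)] for i, j in zip(induced_src[src].cpu().numpy(), induced_dst[dst].cpu().numpy())]

# Pass get_attention by keyword: in recent DGL versions the third
# positional argument of GATConv.forward is edge_weight.
h, att = gat_conv_layer(block, h, get_attention=True)
# Layer 1: log the 8 heads; att has shape (num_block_edges, num_heads, 1)
attentions[att_edge_order, :-1] = att.squeeze(-1)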
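The dictionary lookup in att_edge_order runs a Python loop over every MFG edge. One possible vectorized alternative, sketched here in NumPy and not a DGL API, encodes each original edge (u, v) as the single integer key u * num_nodes + v and searches the sorted keys with np.searchsorted. The helpers build_edge_index and lookup_edge_ids are hypothetical, the toy graph is made up, and the sketch assumes no parallel edges (with duplicates it returns the lowest edge ID). DGL graphs also expose g.edge_ids(u, v), which performs this lookup directly.

```python
import numpy as np

def build_edge_index(src, dst, num_nodes):
    """Encode each original edge (u, v) as the key u * num_nodes + v and sort the keys.
    Returns the sorted keys and, for each sorted key, its original edge ID."""
    keys = src.astype(np.int64) * num_nodes + dst.astype(np.int64)
    eid_of_sorted = np.argsort(keys, kind="stable")
    return keys[eid_of_sorted], eid_of_sorted

def lookup_edge_ids(sorted_keys, eid_of_sorted, u, v, num_nodes):
    """Vectorized replacement for the per-edge dict lookup; raises KeyError
    if some (u, v) pair is not an edge of the original graph."""
    queries = u.astype(np.int64) * num_nodes + v.astype(np.int64)
    pos = np.clip(np.searchsorted(sorted_keys, queries), 0, len(sorted_keys) - 1)
    if not np.array_equal(sorted_keys[pos], queries):
        raise KeyError("some (u, v) pairs are not edges of the original graph")
    return eid_of_sorted[pos]

# Toy original graph (hypothetical): edge IDs 0..3 are 0->1, 1->2, 2->3, 3->0.
g_src, g_dst = np.array([0, 1, 2, 3]), np.array([1, 2, 3, 0])
sorted_keys, eid_of_sorted = build_edge_index(g_src, g_dst, num_nodes=4)

# Global endpoints of a block's edges, e.g. induced_src[src] and induced_dst[dst].
print(lookup_edge_ids(sorted_keys, eid_of_sorted, np.array([2, 1]), np.array([3, 2]), num_nodes=4))  # [2 1]
```

The index is built once per graph, so each minibatch only pays for one searchsorted call instead of a Python loop.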

GATConv has a get_attention argument; when it is set to True, the layer also returns the attention scores on each edge. To map these scores back to the original graph, use the MFG’s edata[dgl.EID], which contains the edge IDs in the original graph.
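With original edge IDs in hand, logging reduces to a scatter into one row per original edge. The NumPy sketch below follows the 8 + 1 head layout from the question (columns 0 to 7 for layer 1, column 8 for layer 2) and assumes attention arrives with shape (num_block_edges, num_heads, 1), as GATConv returns it. The helper log_attention and all array values are hypothetical; in DGL the eid argument would be block.edata[dgl.EID] (0.8 and later) or the att_edge_order list from the snippet above.

```python
import numpy as np

NUM_EDGES = 4                # edges in the original graph (toy value)
HEADS_L1, HEADS_L2 = 8, 1    # 8 layer-1 heads, 1 layer-2 head

# One row per original edge: columns 0-7 hold layer-1 heads, column 8 the layer-2 head.
attentions = np.zeros((NUM_EDGES, HEADS_L1 + HEADS_L2), dtype=np.float32)

def log_attention(attentions, eid, att, col_start):
    """Write per-edge attention of shape (E_block, num_heads, 1) into the rows
    given by eid (original edge IDs) starting at column col_start."""
    att = att.squeeze(-1)    # (E_block, num_heads)
    attentions[eid, col_start:col_start + att.shape[1]] = att
    return attentions

# Layer 1, one hypothetical block covering original edges 2 and 1.
log_attention(attentions, np.array([2, 1]), np.full((2, HEADS_L1, 1), 0.5, dtype=np.float32), col_start=0)
# Layer 2, one hypothetical block covering original edge 2.
log_attention(attentions, np.array([2]), np.ones((1, HEADS_L2, 1), dtype=np.float32), col_start=HEADS_L1)
print(attentions[2])  # eight 0.5 values followed by 1.0
```

If the same edge appears in more than one MFG (for example with neighbor sampling), later writes simply overwrite earlier ones.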

If I am not mistaken, block.edata[dgl.EID] is remapped with respect to each MFG, so it starts at 0 for every minibatch MFG. How can it be related back to the original graph’s edge order?
Could you please check whether the code snippet above, where I used dgl.NID to recover the mapping, is correct?

Starting with DGL 0.8, blocks[0].edata[dgl.EID] contains the edge IDs in the original graph. Which version are you using?

Oh, thanks. I am using 0.7 and will switch to 0.8.

For those on earlier versions, the att_edge_order computed by this code is equivalent to block.edata[dgl.EID] in DGL 0.8 and later.
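For code that has to run on both sides of the 0.8 change, one option is a small version-agnostic helper that uses block.edata[dgl.EID] when it refers to the original graph and otherwise falls back to the dgl.NID lookup from the snippet above. The helper block_edge_ids is hypothetical, not part of DGL, and the toy data is made up to reproduce the edges (2, 3) and (1, 2) quoted earlier; both paths return the same IDs, which is the equivalence stated above.

```python
import numpy as np

def block_edge_ids(local_src, local_dst, induced_src, induced_dst, edge_mapping, block_eid=None):
    """Original-graph edge IDs for one MFG's edges, in block edge-ID order.

    block_eid: block.edata[dgl.EID] when it refers to the original graph (DGL >= 0.8);
    pass None on older versions to fall back to the (src, dst) lookup via dgl.NID.
    """
    if block_eid is not None:
        return np.asarray(block_eid)
    g_src = induced_src[local_src]   # original source node IDs
    g_dst = induced_dst[local_dst]   # original destination node IDs
    return np.array([edge_mapping[(int(u), int(v))] for u, v in zip(g_src, g_dst)])

# Toy original graph: (src, dst) -> edge ID, as built by mapping_edges (hypothetical data).
edge_mapping = {(0, 1): 0, (1, 2): 1, (2, 3): 2, (3, 0): 3}

# Toy MFG whose destination nodes are 3 and 2.
local_src, local_dst = np.array([1, 2]), np.array([0, 1])
induced_src, induced_dst = np.array([3, 2, 1]), np.array([3, 2])

via_nid = block_edge_ids(local_src, local_dst, induced_src, induced_dst, edge_mapping)
via_eid = block_edge_ids(local_src, local_dst, induced_src, induced_dst, edge_mapping, block_eid=[2, 1])
print(via_nid.tolist(), via_eid.tolist())  # [2, 1] [2, 1]
```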
