[Blog] Understand Graph Attention Network

Is there a way to obtain the EdgeBatch “edges” argument from the edge_attention function?
I am trying to set up a custom edge_attention function that also takes the edge type (etype) as an argument. For that to work, I also have to pass “edges,” an EdgeBatch instance, myself.
My attempt at this custom function is below. Thanks in advance.

def edge_attention(self, edges, edgetype):
    # edge UDF for equation (2)
    z_str = edgetype + '_z'
    e_str = edgetype + '_e'
    # concatenate source and destination features of each edge
    z2 = torch.cat([edges.src[z_str], edges.dst[z_str]], dim=1)
    # forward attention layer, looked up by edge type
    key = edgetype + "att"
    a = self.n_embeds[key](z2)
    return {e_str: F.leaky_relu(a)}

Never mind, you can just call g.apply_edges(attn_func, etype=etype) instead of making attn_func cater to different edge types.

EDIT: Now I am stuck on how to make the “a = self.attn_fc(z2)” portion of the edge_attention function cater to different edge types.

You can have a separate attn_fc module for each edge type.
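A minimal sketch of that idea, assuming the per-type modules are kept in a torch.nn.ModuleDict keyed by edge type. The class name HeteroEdgeAttention, the make_edge_udf helper, the edge type names, and the feature keys "z" and "e" are all hypothetical, and a SimpleNamespace stands in for the EdgeBatch that DGL would pass, so the snippet runs with PyTorch alone:

```python
from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F


class HeteroEdgeAttention(nn.Module):
    """Holds one attention scorer per edge type (hypothetical helper, not a DGL class)."""

    def __init__(self, etypes, out_dim):
        super().__init__()
        # A separate attn_fc for each edge type, keyed by the edge type name.
        self.attn_fc = nn.ModuleDict(
            {etype: nn.Linear(2 * out_dim, 1, bias=False) for etype in etypes}
        )

    def make_edge_udf(self, etype):
        """Return a one-argument edge UDF bound to the attn_fc of `etype`."""
        attn_fc = self.attn_fc[etype]

        def edge_attention(edges):
            # Concatenate source and destination features, then score the edge.
            z2 = torch.cat([edges.src["z"], edges.dst["z"]], dim=1)
            return {"e": F.leaky_relu(attn_fc(z2))}

        return edge_attention


# Stand-in for an EdgeBatch: 3 edges, 4-dimensional node features (assumed shapes).
torch.manual_seed(0)
layer = HeteroEdgeAttention(["follows", "likes"], out_dim=4)
fake_edges = SimpleNamespace(
    src={"z": torch.randn(3, 4)},
    dst={"z": torch.randn(3, 4)},
)
scores = layer.make_edge_udf("follows")(fake_edges)
print(scores["e"].shape)
```

On a real heterogeneous graph, the UDF would be passed per edge type in the way mentioned earlier in the thread, roughly g.apply_edges(layer.make_edge_udf(etype), etype=etype), so the UDF itself keeps the usual single “edges” argument.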

It seems that you are applying this on a heterogeneous graph. Please feel free to open a new thread if you would like to follow up. Thanks!
