Is there a way to obtain the EdgeBatch “edges” argument from the edge_attention function?
I am trying to set up a custom edge_attention function that also takes the edge type (etype) as an argument. For that to work, though, I also have to pass “edges”, an EdgeBatch instance.
Please see my attempt at this custom function below. Thanks in advance.
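import torch
import torch.nn.functional as F

def edge_attention(self, edges, edgetype):
    # edge UDF for equation (2)
    z_str = edgetype + '_z'
    e_str = edgetype + '_e'
    # concatenate source and destination node features along the feature axis
    z2 = torch.cat([edges.src[z_str], edges.dst[z_str]], dim=1)
    # forward attention layer: look up the layer for this edge type
    key = edgetype + "att"
    a = self.n_embeds[key](z2)
    return {e_str: F.leaky_relu(a)}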
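One possible way around this, sketched on the assumption that the caller (for example DGL's apply_edges) invokes the UDF with the EdgeBatch as its only argument, is to bind the edge type ahead of time with functools.partial, so that "edges" never has to be obtained by hand. FakeEdgeBatch and the apply_edges function below are hypothetical stand-ins so the sketch runs without DGL or PyTorch; they are not the library's classes, and the edge type name is made up.

```python
from functools import partial


class FakeEdgeBatch:
    """Hypothetical stand-in for an EdgeBatch: only src/dst feature dicts."""

    def __init__(self, src, dst):
        self.src = src
        self.dst = dst


def edge_attention(edges, edgetype):
    # Same idea as the UDF above: feature names are derived from the edge type.
    z_str = edgetype + '_z'
    e_str = edgetype + '_e'
    # Plain list concatenation stands in for torch.cat in this sketch.
    z2 = edges.src[z_str] + edges.dst[z_str]
    return {e_str: z2}


def apply_edges(func, edges):
    # Stand-in for the framework: it calls the UDF with the edge batch only.
    return func(edges)


edges = FakeEdgeBatch(src={'follows_z': [1.0, 2.0]}, dst={'follows_z': [3.0]})

# Bind edgetype now; the resulting callable takes just `edges`.
udf = partial(edge_attention, edgetype='follows')
result = apply_edges(udf, edges)
print(result)
```

With DGL itself, the equivalent call would presumably look like g.apply_edges(partial(self.edge_attention, edgetype=etype), etype=etype), where g is the heterograph; that line is an untested assumption rather than something taken from the post. A lambda such as lambda edges: self.edge_attention(edges, etype) should behave the same way.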