I wrote the following code to apply GAT in my model, but execution gets stuck at `message_func`. I waited more than 12 hours and there was still no response.
```python
def edge_attention(self, edges):
    # edge UDF for equation (2)
    z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
    a = self.attn_fc(z2)
    return {'e': F.leaky_relu(a)}

def message_func(self, edges):
    # message UDF for equations (3) & (4)
    print('msg')
    return {'z': edges.src['z'], 'e': edges.data['e']}

def reduce_func(self, nodes):
    # reduce UDF for equations (3) & (4)
    # equation (3)
    print('rdc')
    alpha = F.softmax(nodes.mailbox['e'], dim=1)
    # equation (4)
    h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
    return {'h': h}

def forward(self, x, g):
    print("1")
    self.g_tmp = g.local_var()
    x_tangent = self.manifold.logmap0(x, c=self.c)
    self.g_tmp.ndata['z'] = x_tangent
    print("2")
    self.g_tmp.apply_edges(self.edge_attention)
    print("3")
    print(self.g_tmp.num_nodes(), self.g_tmp.ndata['z'].shape, self.g_tmp.ndata['z'])
    print(self.g_tmp.num_edges(), self.g_tmp.edata['e'].shape, self.g_tmp.edata['e'])
    self.g_tmp.update_all(message_func=self.message_func, reduce_func=self.reduce_func)
    print("4")
    h = self.g_tmp.ndata.pop('h')
```
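In case it is relevant: I believe the same update can also be expressed with DGL's built-in message/reduce functions plus `edge_softmax`, avoiding Python UDFs entirely. A minimal sketch, assuming `dgl.function` and `dgl.ops.edge_softmax` are available in my DGL version, with a hypothetical edge field `'a'` holding the normalized attention:

```python
# Sketch of a UDF-free equivalent, as a drop-in for the update_all call in forward().
# Assumes the same 'z' node feature and 'e' edge logits as above; 'a' is a
# hypothetical edge field introduced here for the softmax-normalized attention.
import dgl.function as fn
from dgl.ops import edge_softmax

# equation (3): softmax of the attention logits over each node's incoming edges
self.g_tmp.edata['a'] = edge_softmax(self.g_tmp, self.g_tmp.edata['e'])
# equation (4): weighted sum of source features via built-in message/reduce
self.g_tmp.update_all(fn.u_mul_e('z', 'a', 'm'), fn.sum('m', 'h'))
h = self.g_tmp.ndata.pop('h')
```

I have not switched to this yet because I would still like to understand why the UDF version hangs.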
The printed output from running this code is as follows:
```
1
2
3
19717
torch.Size([19717, 32])
tensor([[ 0.0043, -0.0183, -0.0371,  ..., -0.0039,  0.0298, -0.0170],
        [ 0.0599,  0.0057, -0.0073,  ..., -0.0088,  0.0233,  0.0396],
        [ 0.0013, -0.0271,  0.0047,  ...,  0.0473, -0.0359, -0.0024],
        ...,
        [ 0.0205, -0.0751,  0.0386,  ...,  0.0080, -0.0267,  0.0034],
        [ 0.0254, -0.0522,  0.0202,  ...,  0.0161, -0.0832,  0.0664],
        [ 0.0222, -0.0242, -0.0389,  ...,  0.0055,  0.0194,  0.0421]],
       device='cuda:0', grad_fn=<...>)
49219
torch.Size([49219, 1])
tensor([[ 0.0697],
        [-0.0006],
        [-0.0009],
        ...,
        [-0.0003],
        [ 0.0134],
        [ 0.1184]], device='cuda:0', grad_fn=<...>)
msg
```
Since '4' and 'rdc' were never printed, it seems `message_func` ran but `reduce_func` never did, even though the data that `message_func` needs looks correct, so I can't figure out what the problem is.
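To try to isolate the hang, here is a hypothetical minimal reproduction I put together (the graph size, feature width, and `dgl.rand_graph` usage are my own stand-ins, not the real model): the same message/reduce UDFs on a small random CPU graph.

```python
# Hypothetical minimal repro (my own construction, not the real model):
# run the same UDF pipeline on a small random CPU graph to check whether
# update_all itself hangs or the problem is specific to my graph/device.
import dgl
import torch
import torch.nn.functional as F

g = dgl.rand_graph(100, 500)          # 100 nodes, 500 random edges, on CPU
g.ndata['z'] = torch.randn(100, 32)   # same feature width as my model
g.edata['e'] = torch.randn(500, 1)    # stand-in for the attention logits

def message_func(edges):
    return {'z': edges.src['z'], 'e': edges.data['e']}

def reduce_func(nodes):
    alpha = F.softmax(nodes.mailbox['e'], dim=1)                 # equation (3)
    return {'h': torch.sum(alpha * nodes.mailbox['z'], dim=1)}   # equation (4)

g.update_all(message_func, reduce_func)
print(g.ndata['h'].shape)   # expect torch.Size([100, 32]) if the UDF path is fine
```

If this small case completes instantly, the hang would seem to be specific to the large graph or the CUDA device rather than to the UDF logic itself.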