message_func stops running in GAT

I wrote the following code to apply GAT in my model, but execution gets stuck at message_func. I waited more than 12 hours and still got no response.

def edge_attention(self, edges):
    """Edge UDF for equation (2): one unnormalised attention logit per edge.

    Concatenates each edge's source and destination ``'z'`` features along
    dim 1, projects the pair with ``self.attn_fc`` and applies a leaky ReLU.

    Returns:
        ``{'e': logits}`` -- ``apply_edges`` stores this as edge data ``'e'``.
    """
    src_dst = torch.cat((edges.src['z'], edges.dst['z']), dim=1)
    score = self.attn_fc(src_dst)
    return {'e': F.leaky_relu(score)}

def message_func(self, edges):
    """Message UDF for equations (3) & (4).

    Prints a ``'msg'`` debug marker, then sends each edge's source feature
    ``'z'`` and its attention logit ``'e'`` to the destination's mailbox.
    """
    print('msg')
    src_feat = edges.src['z']
    logit = edges.data['e']
    return {'z': src_feat, 'e': logit}

def reduce_func(self, nodes):
    """Reduce UDF for equations (3) & (4).

    Prints an ``'rdc'`` debug marker. It then softmax-normalises the incoming
    attention logits over the mailbox axis (dim 1) to obtain alpha, per
    equation (3). Finally it returns the alpha-weighted sum of the incoming
    ``'z'`` messages as ``'h'``, per equation (4).
    """
    print('rdc')
    mailbox = nodes.mailbox
    weights = F.softmax(mailbox['e'], dim=1)            # equation (3)
    return {'h': (weights * mailbox['z']).sum(dim=1)}   # equation (4)

def forward(self, x, g):
    """Run one GAT-style attention/aggregation pass over graph ``g``.

    Steps (the numbered ``print`` calls are debugging markers):
      1. take ``g.local_var()`` -- per the DGL docs, feature writes on it do
         not propagate back to ``g``;
      2. map ``x`` with ``self.manifold.logmap0(x, c=self.c)`` and store the
         result as node feature ``'z'``;
      3. compute per-edge attention logits ``'e'`` via ``self.edge_attention``;
      4. aggregate with ``self.message_func`` / ``self.reduce_func`` into
         node feature ``'h'``.

    Args:
        x: node feature tensor; presumably (num_nodes, dim) -- confirm.
        g: the DGL graph to run on.

    Returns:
        The aggregated node features ``'h'``.
    """
    print("1")
    # NOTE(review): keeping the graph on ``self`` holds it (and its feature
    # tensors) alive until the next call; a plain local would avoid that.
    self.g_tmp = g.local_var()
    graph = self.g_tmp
    x_tangent = self.manifold.logmap0(x, c=self.c)
    graph.ndata['z'] = x_tangent
    print("2")
    graph.apply_edges(self.edge_attention)
    print("3")
    print(graph.num_nodes(), graph.ndata['z'].shape, graph.ndata['z'])
    print(graph.num_edges(), graph.edata['e'].shape, graph.edata['e'])
    # NOTE(review): this thread reports update_all hanging after the first
    # message_func call with DGL 0.5 on PyTorch 1.3; upgrading PyTorch to
    # 1.6+ (1.7 was confirmed) fixed it.
    graph.update_all(message_func=self.message_func, reduce_func=self.reduce_func)
    print("4")
    # Pop so 'h' does not linger in the graph kept on self.g_tmp; return it
    # instead of silently discarding the layer's output.
    h = graph.ndata.pop('h')
    return h

And the printed output is as follows:


1
2
3
19717
torch.Size([19717, 32])
tensor([[ 0.0043, -0.0183, -0.0371, …, -0.0039, 0.0298, -0.0170],
[ 0.0599, 0.0057, -0.0073, …, -0.0088, 0.0233, 0.0396],
[ 0.0013, -0.0271, 0.0047, …, 0.0473, -0.0359, -0.0024],
…,
[ 0.0205, -0.0751, 0.0386, …, 0.0080, -0.0267, 0.0034],
[ 0.0254, -0.0522, 0.0202, …, 0.0161, -0.0832, 0.0664],
[ 0.0222, -0.0242, -0.0389, …, 0.0055, 0.0194, 0.0421]],
device='cuda:0', grad_fn=<...>)
49219
torch.Size([49219, 1])
tensor([[ 0.0697],
[-0.0006],
[-0.0009],
…,
[-0.0003],
[ 0.0134],
[ 0.1184]], device='cuda:0', grad_fn=<...>)
msg

Since '4' and 'rdc' were never printed, it seems reduce_func never ran. The data that message_func needs looks correct, though, so I can't figure out what the problem is.

Works for me:

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    """Minimal single-head GAT layer used to reproduce the ``update_all`` hang.

    Holds a single linear layer, ``attn_fc``. It maps a concatenated
    (source, destination) feature pair of width ``2 * d`` to one scalar
    attention logit. The ``print`` calls are debugging markers.
    """

    def __init__(self, d):
        super().__init__()
        in_width = 2 * d
        self.attn_fc = nn.Linear(in_width, 1)

    def edge_attention(self, edges):
        """Edge UDF for equation (2): leaky-ReLU'd logit per edge as ``'e'``."""
        src_dst = torch.cat((edges.src['z'], edges.dst['z']), dim=1)
        score = self.attn_fc(src_dst)
        return {'e': F.leaky_relu(score)}

    def message_func(self, edges):
        """Message UDF for equations (3) & (4): forward source ``'z'`` and ``'e'``."""
        print('msg')
        src_feat = edges.src['z']
        logit = edges.data['e']
        return {'z': src_feat, 'e': logit}

    def reduce_func(self, nodes):
        """Reduce UDF: softmax over mailbox logits (3), weighted sum of ``'z'`` (4)."""
        print('rdc')
        mailbox = nodes.mailbox
        weights = F.softmax(mailbox['e'], dim=1)            # equation (3)
        return {'h': (weights * mailbox['z']).sum(dim=1)}   # equation (4)

    def forward(self, x, g):
        """Attend and aggregate over a local copy of ``g``; return node features ``'h'``.

        Unlike the original poster's layer, ``x`` is used directly as ``'z'``
        (no manifold log-map) to keep the reproduction self-contained.
        """
        print("1")
        graph = self.g_tmp = g.local_var()
        graph.ndata['z'] = x
        print("2")
        graph.apply_edges(self.edge_attention)
        print("3")
        z = graph.ndata['z']
        print(graph.num_nodes(), z.shape, z)
        e = graph.edata['e']
        print(graph.num_edges(), e.shape, e)
        graph.update_all(message_func=self.message_func, reduce_func=self.reduce_func)
        print("4")
        return graph.ndata.pop('h')

# Repro setup: 10,000 random (src, dst) pairs over 2,000 nodes, added in both
# directions, with 20-dim random node features. Everything lives on CUDA.
net = Net(20).to('cuda')
src = torch.randint(0, 2000, (10000,))
dst = torch.randint(0, 2000, (10000,))
g = dgl.graph((torch.cat([src, dst]), torch.cat([dst, src])), num_nodes=2000).to('cuda')
x = torch.randn(2000, 20).to('cuda')

net(x, g)

Output:

1
2
3
2000 torch.Size([2000, 20]) tensor([[-0.5073,  0.0447,  0.4622,  ...,  0.2164,  0.8741, -1.4115],
        [ 1.4964, -0.6950, -0.4422,  ..., -0.6878, -1.8845, -1.0926],
        [-2.4776, -0.9497,  0.1368,  ...,  1.4140, -1.4630,  0.0765],
        ...,
        [-0.8655, -1.9890,  1.8221,  ..., -0.4927, -0.1099,  0.1564],
        [-0.7141,  0.5261, -0.4841,  ...,  0.4287,  0.8080,  0.3688],
        [-1.2275, -0.7392,  1.6882,  ..., -0.9178, -0.4787,  0.7999]],
       device='cuda:0')
20000 torch.Size([20000, 1]) tensor([[ 0.0785],
        [-0.0017],
        [-0.0060],
        ...,
        [-0.0018],
        [-0.0074],
        [ 0.1303]], device='cuda:0', grad_fn=<LeakyReluBackward0>)
msg
rdc
rdc
rdc
rdc
rdc
rdc
rdc
rdc
rdc
rdc
rdc
rdc
rdc
rdc
rdc
rdc
rdc
rdc
rdc
rdc
rdc
4

Which DGL/PyTorch/CUDA version are you using?

FYI, DGL also has a built-in dgl.nn.GATConv module for this purpose, which is much faster.

DGL 0.5.3, PyTorch 1.3.0, CUDA 10.0.
I have some other operations in the first step to compute WH, so I have to write it this way.

Could you upgrade PyTorch to 1.6+? DGL 0.5+ doesn't work well with PyTorch 1.5 or earlier.

Thanks, I updated PyTorch to 1.7 and it works now.

This topic was automatically closed 30 days after the last reply. New replies are no longer allowed.