Using node and edge features in message passing

Hello! I am investigating how to use both node and edge features during message passing for a rather simple model such as GraphSAGE. I think the most direct way to implement it would be something like:

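import dgl
import torch

def edata_message_func(edges):
    # Concatenate source-node features and edge features into one message per edge
    return {'m': torch.cat((edges.src['h'], edges.data['e']), 1)}

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1], [1, 2])
g.ndata['h'] = torch.tensor([[1.], [2.], [3.]])
g.edata['e'] = torch.tensor([[3., 5.], [4., 6.]])
# Average the concatenated messages arriving at each destination node
g.update_all(edata_message_func,
             dgl.function.mean('m', 'h'))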

I believe this would work inside the models (please correct me if I am wrong). However, it relies on a user-defined function (UDF), which DGL cannot optimize into an SpMV kernel, so there will be some performance cost.
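To make the semantics of the UDF version concrete, here is a dependency-free sketch in plain Python that models what `edata_message_func` followed by a mean reduce computes on the toy graph above. It does not use DGL; `concat_then_mean` is a hypothetical helper, and filling nodes that have no incoming edges with zeros is an assumption made here to keep every output row the same length.

```python
# Hypothetical helper: a plain-Python model of "concatenate source-node and
# edge features into one message, then average the messages per destination".
def concat_then_mean(src, dst, h, e):
    num_nodes = len(h)
    dim = len(h[0]) + len(e[0])
    sums = [[0.0] * dim for _ in range(num_nodes)]
    counts = [0] * num_nodes
    for k, (u, v) in enumerate(zip(src, dst)):
        msg = h[u] + e[k]  # list concatenation, like torch.cat(..., 1) per edge
        sums[v] = [a + b for a, b in zip(sums[v], msg)]
        counts[v] += 1
    # Nodes with no incoming edges keep a zero vector (assumption, see above).
    return [[x / c for x in row] if c else row for row, c in zip(sums, counts)]

# Toy graph from the question: edges 0->1 and 1->2.
src, dst = [0, 1], [1, 2]
h = [[1.0], [2.0], [3.0]]
e = [[3.0, 5.0], [4.0, 6.0]]
print(concat_then_mean(src, dst, h, e))
# [[0.0, 0.0, 0.0], [1.0, 3.0, 5.0], [2.0, 4.0, 6.0]]
```

Each row is one node's aggregated message: the first column comes from the neighbour's `h`, the remaining two from the edge's `e`.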

Is there a way to use built-in functions for this kind of message passing? It seems to me that the code below should work; do you agree?

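import dgl.function as fn

# Pass 1: sum the edge features arriving at each destination node
g.update_all(fn.copy_e('e', 'm_e'),
             fn.sum('m_e', 'h_e'))
# Pass 2: sum the source-node features arriving at each destination node
g.update_all(fn.copy_u('h', 'm_n'),
             fn.sum('m_n', 'h_n'))
# Concatenate the two aggregated results per node
g.ndata['h'] = torch.cat((g.ndata.pop('h_n'), g.ndata.pop('h_e')), 1)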
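The two-pass version gives the same result as the concatenated-message version because summation acts element-wise: summing concatenated messages equals concatenating the separate sums, provided both versions use the same reducer (the first snippet uses `mean`, where the argument still holds because both halves of a node's message are divided by the same in-degree). The plain-Python check below illustrates this on a small hypothetical graph; `sum_per_dst` is a hypothetical helper that models `fn.sum`, not a DGL function.

```python
def sum_per_dst(dst, msgs, num_nodes):
    """Sum per-edge message vectors into their destination nodes (models fn.sum)."""
    dim = len(msgs[0])
    out = [[0.0] * dim for _ in range(num_nodes)]
    for v, m in zip(dst, msgs):
        out[v] = [a + b for a, b in zip(out[v], m)]
    return out

src, dst = [0, 1, 2], [2, 2, 0]  # hypothetical graph: node 2 has two in-edges
h = [[1.0], [2.0], [3.0]]
e = [[3.0, 5.0], [4.0, 6.0], [7.0, 8.0]]
n = len(h)

# One pass: concatenate per edge, then sum.
joint = sum_per_dst(dst, [h[u] + e[k] for k, u in enumerate(src)], n)

# Two passes (copy_u / copy_e style), then concatenate the node-level results.
h_n = sum_per_dst(dst, [h[u] for u in src], n)
h_e = sum_per_dst(dst, e, n)
split = [a + b for a, b in zip(h_n, h_e)]

print(joint == split)  # True
```

The practical benefit of the split form is that each pass uses only built-in message and reduce functions.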

BTW, do you expect I will run into other problems using DGL to pass both edge and node features in models such as GraphSAGE and GAT?


Yes, you are right. By the way, we also support passing a list of message/reduce functions, which means that:

g.update_all([fn.copy_e('e', 'm_e'), fn.copy_u('h', 'm_n')], 
             [fn.sum('m_e', 'h_e'), fn.sum('m_n', 'h_n')])

is valid in DGL.

I think most operations can be formulated as a combination of built-in functions. If not, you can fall back to a UDF.
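As one example of composing built-ins, a GAT-style step that scales each neighbour's features by a per-edge attention weight could be written with the built-in pair `fn.u_mul_e('h', 'a', 'm')` and `fn.sum('m', 'h')`, assuming the weights are stored in a hypothetical edge field `a` with a trailing dimension of 1 so they broadcast over the feature dimension. The plain-Python sketch below only models what that pair computes; computing and normalising the weights themselves is left out, and the weight values are illustrative.

```python
def weighted_sum(src, dst, h, a, num_nodes):
    """Model of u_mul_e + sum: scale the source feature by the edge weight, then sum."""
    dim = len(h[0])
    out = [[0.0] * dim for _ in range(num_nodes)]
    for k, (u, v) in enumerate(zip(src, dst)):
        msg = [a[k] * x for x in h[u]]                 # u_mul_e: scalar edge weight
        out[v] = [p + q for p, q in zip(out[v], msg)]  # sum reducer
    return out

src, dst = [0, 1], [2, 2]      # nodes 0 and 1 both point to node 2
h = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]
a = [0.25, 0.75]               # illustrative weights that already sum to 1 for node 2
print(weighted_sum(src, dst, h, a, 3))
# [[0.0, 0.0], [0.0, 0.0], [2.5, 3.5]]
```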


That’s great! Thanks for your help!

Great! That’s just what I wanted!