Hello! I am investigating how to use both node and edge features during message passing in a rather simple model such as GraphSAGE. I think the most direct way to implement it would be something like:
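import torch
import dgl

# Message: concatenate the source node's features with the edge's features.
def edata_message_func(edges):
    return {'m': torch.cat((edges.src['h'], edges.data['e']), 1)}

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1], [1, 2])
g.ndata['h'] = torch.tensor([[1.], [2.], [3.]])
g.edata['e'] = torch.tensor([[3., 5.], [4., 6.]])
# UDF message function combined with the built-in mean reducer.
g.update_all(edata_message_func,
             dgl.function.mean('m', 'h'))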
I believe that using something like this in a model would work (please correct me if I am wrong). However, because the message function is user-defined, DGL cannot run it as an optimized sparse matrix operation (SpMV), so there will be some performance cost.
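To make explicit what the UDF version computes, here is a plain-Python sketch (no DGL or PyTorch; the function name `mean_concat_messages` is hypothetical) that mirrors it on the same three-node graph: each edge u→v sends `h[u]` concatenated with that edge's `e`, and each destination averages what it receives. Filling nodes that have no incoming edges with zeros is an assumption about how DGL's built-in reducers behave, not something taken from the post.

```python
def mean_concat_messages(num_nodes, src, dst, h, e):
    """Mimic update_all(edata_message_func, mean('m', 'h')) in plain Python.

    h: per-node feature lists; e: per-edge feature lists (same order as src/dst).
    """
    # Each message is concat(h[src], e[edge]), like torch.cat(..., 1) per edge.
    inbox = {v: [] for v in range(num_nodes)}
    for eid, (u, v) in enumerate(zip(src, dst)):
        inbox[v].append(h[u] + e[eid])

    msg_dim = len(h[0]) + len(e[0])
    out = []
    for v in range(num_nodes):
        msgs = inbox[v]
        if not msgs:
            # Assumption: zero-in-degree nodes end up with all-zero features.
            out.append([0.0] * msg_dim)
            continue
        # Element-wise mean over the incoming messages.
        out.append([sum(col) / len(msgs) for col in zip(*msgs)])
    return out


h = [[1.0], [2.0], [3.0]]
e = [[3.0, 5.0], [4.0, 6.0]]
print(mean_concat_messages(3, [0, 1], [1, 2], h, e))
# [[0.0, 0.0, 0.0], [1.0, 3.0, 5.0], [2.0, 4.0, 6.0]]
```

Note that the output width is the node-feature width plus the edge-feature width (1 + 2 = 3 here), so any layer applied afterwards has to expect that size.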
Is there a way to do this kind of message passing with built-in functions? It seems to me that the code below should work; do you agree?
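import dgl.function as fn

# Aggregate edge features; mean (not sum) to match the UDF version above.
g.update_all(fn.copy_e('e', 'm_e'),
             fn.mean('m_e', 'h_e'))
# Aggregate source-node features the same way.
g.update_all(fn.copy_u('h', 'm_n'),
             fn.mean('m_n', 'h_n'))
# Concatenate in the same order as edata_message_func: node part, then edge part.
g.ndata['h'] = torch.cat((g.ndata.pop('h_n'), g.ndata.pop('h_e')), 1)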
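Whether the two-pass built-in version matches the UDF comes down to one identity: reducing concatenated messages gives the same result as concatenating separately reduced parts, as long as both parts use the same element-wise reducer. The plain-Python sketch below (the helper `reduce_rows` is hypothetical; no DGL involved) checks this for a single node with two incoming edges, and also shows that using `sum` in one version and `mean` in the other would not match.

```python
def reduce_rows(rows, how):
    """Element-wise sum or mean over a non-empty list of feature rows."""
    totals = [sum(col) for col in zip(*rows)]
    return totals if how == "sum" else [t / len(rows) for t in totals]


# One destination node with two incoming edges (in-degree 2).
h_src = [[1.0], [2.0]]            # features of the two source nodes
e_in = [[3.0, 5.0], [4.0, 6.0]]   # features of the two incoming edges

for how in ("sum", "mean"):
    # UDF style: concatenate per edge, then reduce once.
    one_pass = reduce_rows([hu + ev for hu, ev in zip(h_src, e_in)], how)
    # Built-in style: reduce node part and edge part separately, then concatenate.
    two_pass = reduce_rows(h_src, how) + reduce_rows(e_in, how)
    print(how, one_pass, two_pass, one_pass == two_pass)

# Mixing reducers (sum for one version, mean for the other) does not match.
print(reduce_rows(h_src, "sum") + reduce_rows(e_in, "sum")
      == reduce_rows([hu + ev for hu, ev in zip(h_src, e_in)], "mean"))
```

Because the identity holds element-wise, the only other requirement is that the concatenation order (`h_n` first, then `h_e`) matches the order used in `edata_message_func`.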
By the way, do you think I will run into other problems using DGL for message passing over both node and edge features in models such as GraphSAGE and GAT?
Thanks!