I see that PyG has a similar implementation.
https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/nn/meta.py
Does DGL have a corresponding class?
I want to use this model in dgl-lifesci, but I don’t know how to implement it.
Hi,
I think that’s essentially identical to writing a custom message passing module.
For graph_nets’ `NodeModel`, DGL has `g.update_all(message_func, reduce_func)` for that purpose. Here, `message_func` is a function that computes a message on each edge, and `reduce_func` is a function that aggregates the incoming messages at each node. DGL also provides builtin message and reduce functions, which serve both as shortcuts and as a way to speed things up.
For graph_nets’ `EdgeModel`, DGL has `g.apply_edges(apply_func)`, where `apply_func` is a function that computes new features for each edge. The DGL builtin message functions can also be used here.
For graph_nets’ `GlobalModel`, it is essentially a readout over each graph in a batch. You can use `dgl.batch()` to merge a list of graphs into one larger graph, run `update_all()` or `apply_edges()` on it, and then use `dgl.sum_nodes()` etc. to, for instance, sum the node representations of each graph in the batch. Alternatively, use `dgl.unbatch()` to split it back into a list of graphs for more complicated computations later.
Please feel free to follow up. It would be helpful if you could provide an example of what your `EdgeModel`, `NodeModel` and `GlobalModel` would look like. Thanks!
At present, I don’t know how to implement it, or whether it can be compatible with torch_scatter as used by PyG.
I think @BarclayII’s reply answers your question exactly: for NodeModel/EdgeModel/GlobalModel we have corresponding APIs.
DGL is compatible with torch_scatter, so feel free to use it if you like; however, our builtin APIs are more efficient.