Hi everyone,
I have a question about the recent version updates (> v0.4) and sending/receiving messages along a graph. I think it is easiest to describe my problem with an example. Assume I have the following directed graph, in which every node has only one successor (direct child) but n parents, where n can be anything from zero to arbitrarily many (in practice something like 0-4).
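The image is not reproduced here. As a purely hypothetical stand-in that is consistent with the text (nodes 1 and 2 feed node 3, and everything drains into the target node 4), assume the edges (0,4), (1,3), (2,3) and (3,4); node 0 and its edge are invented for illustration. Because every node has at most one successor, a plain successor map is enough to store such a graph, and each node's parent list follows from inverting it:

```python
# Hypothetical stand-in for the graph in the (missing) image:
# every node has at most one successor; the target node 4 has none.
successor = {0: 4, 1: 3, 2: 3, 3: 4}

def parents_of(successor):
    """Invert a successor map into a list of parents per node."""
    parents = {}
    for src, dst in successor.items():
        parents.setdefault(dst, []).append(src)
        parents.setdefault(src, [])  # roots end up with an empty parent list
    return parents

parents = parents_of(successor)
print(parents[3], parents[4], parents[0])  # [1, 2] [0, 3] []
```

A node with an empty parent list is a root; any node can have several parents but only one entry in `successor`.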
I am interested in computing some output at node 4, for which I need to traverse the entire graph and pass information along the edges. At each edge, the incoming data is transformed, and this transformation depends on some edge features. For example, the edge feature could be passed through a small neural network to compute a weight, which is then multiplied with the incoming data to produce a message for the destination node.
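To make the per-edge transformation concrete, here is a minimal plain-Python sketch. The "small neural network" is replaced by a hypothetical single linear unit followed by a sigmoid (`edge_weight`, with made-up parameters `w` and `b`); the resulting scalar weight scales the source node's time series to form the message:

```python
import math

def edge_weight(edge_feat, w, b):
    """Hypothetical tiny 'network': one linear unit followed by a sigmoid."""
    z = sum(f * wi for f, wi in zip(edge_feat, w)) + b
    return 1.0 / (1.0 + math.exp(-z))

def make_message(src_series, edge_feat, w, b):
    """Scale the source node's time series by the edge-dependent weight."""
    weight = edge_weight(edge_feat, w, b)
    return [weight * x for x in src_series]

# The one-hot feature selects a zero pre-activation, so the weight is 0.5.
msg = make_message([2.0, 4.0, 6.0], edge_feat=[1.0, 0.0], w=[0.0, 3.0], b=0.0)
print(msg)  # [1.0, 2.0, 3.0]
```

With a one-hot edge encoding, each edge effectively selects its own entry of `w`, i.e. learns its own weight.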
What I did so far was something like:
```python
import dgl
import torch

graph = ...  # graph as in the image
target_node = 4  # for this example, our target is node 4
# a torch Tensor containing a vector/time series for each node [nodes, timesteps]
graph.ndata['x_node'] = ...
# edge features [edges, n_features], or [edges, edges] for a one-hot encoding
graph.edata['x_edge'] = ...
# pass all edge features at once through the neural net to compute the edge weights
graph.apply_edges(func=self._edge_function)
# message function: transform the outgoing src data with the edge weights
graph.register_message_func(self._message_function)
# reduce function: add all incoming messages to the node data to produce the new outgoing data
graph.register_reduce_func(self._reduce_function)
srcs, dsts = graph.edges()
# traverse the graph depth-first, from the roots to the target node
for edge in reversed(dgl.dfs_edges_generator(graph, target_node, reverse=True)):
    # get the ids of the src and dst node of this edge
    src = srcs[edge.item()]
    dst = dsts[edge.item()]
    # A node is src to only one other node. Thus, if a node is src in the DFS traversal,
    # all of its incoming messages have already been computed. Therefore, as a first step,
    # we add any incoming messages to the node data itself to compute the node's outgoing data.
    graph.recv(src)
    # use the outgoing data to compute the message from this source to the destination node
    graph.send((src, dst))
# finally, combine the incoming messages at the target node, since it was never a source node in the traversal
graph.recv(torch.tensor([target_node]))
```
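Why the reversed DFS order works can be checked without DGL. The following plain-Python sketch models the traversal on a hypothetical graph with edges (0,4), (1,3), (2,3), (3,4); it does not call `dgl.dfs_edges_generator`, whose neighbor order may differ. It runs a DFS from the target along reversed edges and then reverses the discovered edge list, which places every edge into a node before that node's own outgoing edge, so `recv(src)` always sees a complete mailbox:

```python
def dfs_edges_reversed_graph(successor, target):
    """DFS from `target` along reversed edges (child -> parent), recording
    each edge as (src, dst) in the order it is discovered."""
    parents = {}
    for src, dst in successor.items():
        parents.setdefault(dst, []).append(src)
    order = []

    def visit(node):
        for p in parents.get(node, []):
            order.append((p, node))
            visit(p)

    visit(target)
    return order

successor = {0: 4, 1: 3, 2: 3, 3: 4}  # hypothetical example graph
edges = list(reversed(dfs_edges_reversed_graph(successor, target=4)))
print(edges)  # [(2, 3), (1, 3), (3, 4), (0, 4)]
```

Because each node has a single successor, the reversed graph seen from the target is a tree, so every edge is discovered exactly once and the edges of a subtree are discovered after the edge leading into it.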
I hope you can follow me so far.
This worked well with DGL < 0.5, where it was possible to call `.send()` and `.recv()` on a graph separately and in any given order.
After a month off, I tried to run this code with DGL v0.6 (I also tried v0.5), but got deprecation errors for `.send()` and `.recv()`. For `send()`, for example, the message is:
> DGLGraph.send is deprecated. As a replacement, use DGLGraph.apply_edges API to compute messages as edge data. Then use DGLGraph.send_and_recv and set the message function as dgl.function.copy_e to conduct message aggregation.
My problem is that I do not always want to send something and immediately perform a receive operation at the destination node. Take node 3 in the image above, for example. Before I can perform the receive operation at node 3 (which sums the mailbox items and the node data to produce the node's output data), I need to send data from node 1 through edge (1,3) and from node 2 through edge (2,3). With the new API, however, I cannot find a way to do that. Is it no longer possible to send data to the mailbox of a node, add further messages to that mailbox later, and only then call a receive function? For example, my reduce function from the code snippet above was:
```python
from typing import Dict

import torch
from dgl.udf import NodeBatch

def _reduce_function(self, nodes: NodeBatch) -> Dict[str, torch.Tensor]:
    # the outflow of a polygon is the sum over the incoming (routed) flows + the polygon discharge itself
    return {'outgoing': nodes.data["x_node"] + torch.sum(nodes.mailbox["incoming"], dim=1)}
```
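For user-defined reduce functions, DGL groups nodes that received the same number of messages, so `nodes.mailbox["incoming"]` has shape [nodes, messages, timesteps] and `torch.sum(..., dim=1)` sums over the messages of each node. A plain-Python equivalent of that reduce step, with made-up numbers for a batch of two nodes that received two messages each:

```python
def reduce_outgoing(node_data, mailbox):
    """Per node: own series + element-wise sum of all its incoming messages.

    node_data: [num_nodes][timesteps]
    mailbox:   [num_nodes][num_messages][timesteps]; summing over the
               message axis corresponds to torch.sum(..., dim=1)
    """
    out = []
    for x, msgs in zip(node_data, mailbox):
        summed = [sum(vals) for vals in zip(*msgs)]  # sum over messages, per timestep
        out.append([a + b for a, b in zip(x, summed)])
    return out

node_data = [[1.0, 1.0], [0.0, 2.0]]
mailbox = [[[1.0, 2.0], [3.0, 4.0]],
           [[0.5, 0.5], [0.5, 0.5]]]
result = reduce_outgoing(node_data, mailbox)
print(result)  # [[5.0, 7.0], [1.0, 3.0]]
```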
With the new `send_and_recv` function, however, this reduce function would compute `outgoing` immediately as soon as I pass a message from node 1 through edge (1,3), and in the next step it would overwrite `outgoing` with just the incoming message from edge (2,3), since the mailbox no longer contains the old message.
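The overwrite behavior, and why grouping edges avoids it, can be modeled in plain Python. This is not DGL code: `send_and_recv` below is a hypothetical stand-in for a single call with the message and reduce functions from this post, using scalar node values instead of time series and made-up weights on a hypothetical graph. Within one call, all messages delivered to a node are reduced together, while a later call starts from an empty mailbox. Passing both in-edges of node 3 in one call therefore sums both messages, and processing the graph in topological frontiers, with all in-edges of a frontier in one call, gives every node exactly one reduce with a complete mailbox:

```python
from collections import defaultdict

def send_and_recv(edges, outgoing, x_node, weight):
    """Hypothetical model of ONE send_and_recv call with this post's UDFs:
    message = weight[edge] * outgoing[src]; reduce = x_node[dst] + sum(messages).
    The mailbox is local to the call, so messages from earlier calls are gone."""
    mailbox = defaultdict(list)
    for src, dst in edges:
        mailbox[dst].append(weight[(src, dst)] * outgoing[src])
    for dst, msgs in mailbox.items():
        outgoing[dst] = x_node[dst] + sum(msgs)  # replaces any previous value

def topo_frontiers(successor, nodes):
    """Yield groups of nodes such that every parent lies in an earlier group."""
    indeg = {n: 0 for n in nodes}
    for dst in successor.values():
        indeg[dst] += 1
    frontier = [n for n in nodes if indeg[n] == 0]  # roots
    while frontier:
        yield frontier
        nxt = []
        for n in frontier:
            dst = successor.get(n)
            if dst is not None:
                indeg[dst] -= 1
                if indeg[dst] == 0:
                    nxt.append(dst)
        frontier = nxt

def propagate(successor, x_node, weight):
    """Roots keep their own data; every other node is reduced once, after all its parents."""
    outgoing = dict(x_node)
    frontiers = list(topo_frontiers(successor, sorted(x_node)))
    for frontier in frontiers[1:]:  # the first frontier holds the roots
        in_edges = [(s, d) for s, d in successor.items() if d in frontier]
        send_and_recv(in_edges, outgoing, x_node, weight)
    return outgoing, frontiers

# Hypothetical graph and data (scalars instead of time series).
successor = {0: 4, 1: 3, 2: 3, 3: 4}
x_node = {0: 5.0, 1: 1.0, 2: 2.0, 3: 10.0, 4: 100.0}
weight = {(0, 4): 0.5, (1, 3): 1.0, (2, 3): 1.0, (3, 4): 1.0}

# One edge per call: the message from node 1 is lost.
out_a = dict(x_node)
send_and_recv([(1, 3)], out_a, x_node, weight)
send_and_recv([(2, 3)], out_a, x_node, weight)
print(out_a[3])  # 12.0 (10 + 2)

# Both in-edges of node 3 in one call: both messages are summed.
out_b = dict(x_node)
send_and_recv([(1, 3), (2, 3)], out_b, x_node, weight)
print(out_b[3])  # 13.0 (10 + 1 + 2)

outgoing, frontiers = propagate(successor, x_node, weight)
print(frontiers)    # [[0, 1, 2], [3], [4]]
print(outgoing[4])  # 115.5 (100 + 0.5 * 5 + 13)
```

In DGL terms, this would mean collecting the in-edges of each frontier (for example with `graph.in_edges(nodes, form='eid')`) and passing them to a single `send_and_recv` call, after initializing `outgoing` of the roots to their `x_node`. `dgl.topological_nodes_generator` and `dgl.prop_nodes_topo` appear to provide this kind of frontier-wise traversal, but whether they fit this exact setup has not been verified here.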
How can I solve this problem with the new API? I guess I am just too stupid to see the solution after 12 hours of digging through the new documentation. I would greatly appreciate any help. If my problem is not clear, please let me know and I will try to update the description.
Thanks in advance.