# Differentiable propagation

Hi!

I have a DAG where each edge has a trainable weight. I need to propagate messages in topological order using differentiable aggregation/reduce functions, then compute a loss at the output and get its gradient w.r.t. the edge weights. Basically, it’s kind of like forward/backward on a neural network. The graph structure is static but very large (a few hundred million nodes and edges), and I need to repeat this operation many times (a few thousand).
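To make the forward/backward analogy concrete, here is a minimal, dependency-free sketch on a toy DAG. It assumes G(m, w) = m * w, F = sum, and a loss that simply sums the end-point values; the real G, F, and L are custom functions, and the helper names `topological_levels` and `forward_backward` are hypothetical. The forward pass walks the topological frontiers in order; the backward pass walks them in reverse, accumulating the gradient w.r.t. each edge weight.

```python
from collections import defaultdict


def topological_levels(num_nodes, edges):
    """Kahn's algorithm grouped into frontiers; nodes in one frontier share no edges."""
    indeg = [0] * num_nodes
    out_edges = defaultdict(list)
    for eid, (u, v) in enumerate(edges):
        indeg[v] += 1
        out_edges[u].append(eid)
    frontier = [n for n in range(num_nodes) if indeg[n] == 0]
    levels = []
    while frontier:
        levels.append(frontier)
        nxt = []
        for u in frontier:
            for eid in out_edges[u]:
                v = edges[eid][1]
                indeg[v] -= 1
                if indeg[v] == 0:
                    nxt.append(v)
        frontier = nxt
    return levels


def forward_backward(num_nodes, edges, weights, init):
    """Forward: M[v] = sum over in-edges of M[u] * W[e] (assumed G and F).
    Loss: sum of M at the end points. Returns (loss, dLoss/dW)."""
    levels = topological_levels(num_nodes, edges)
    in_edges = defaultdict(list)
    out_deg = [0] * num_nodes
    for eid, (u, v) in enumerate(edges):
        in_edges[v].append(eid)
        out_deg[u] += 1

    m = [0.0] * num_nodes
    for n, value in init.items():
        m[n] = value  # initial values at the start points
    for frontier in levels[1:]:  # skip the sources, like order[1:]
        for v in frontier:
            m[v] = sum(m[edges[e][0]] * weights[e] for e in in_edges[v])

    endpoints = [n for n in range(num_nodes) if out_deg[n] == 0]
    loss = sum(m[n] for n in endpoints)

    # Backward: reverse frontier order ensures grad_m[v] is complete when v is visited.
    grad_m = [0.0] * num_nodes
    for n in endpoints:
        grad_m[n] = 1.0  # dLoss/dM at each end point
    grad_w = [0.0] * len(edges)
    for frontier in reversed(levels[1:]):
        for v in frontier:
            for e in in_edges[v]:
                u = edges[e][0]
                grad_w[e] += grad_m[v] * m[u]        # d(M[u] * W[e]) / dW[e] = M[u]
                grad_m[u] += grad_m[v] * weights[e]  # d(M[u] * W[e]) / dM[u] = W[e]
    return loss, grad_w
```

For the DAG 0→2, 1→2, 2→3 with weights [2.0, 3.0, 0.5] and start values {0: 1.0, 1: 4.0}, this gives a loss of 7.0 and edge-weight gradients [0.5, 2.0, 14.0].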

I have tried implementing this propagation with user-defined (UDF) message and reduce functions, but it’s pretty slow on deep graphs:

```python
import dgl
import torch


def message_func(edges):
    # G is a custom differentiable function of the source value and the edge weight
    return {"M": G(edges.src["M"], edges.data["W"])}


def reduce_func(nodes):
    # F is a custom differentiable function over each node's mailbox
    return {"M": F(nodes.mailbox["M"])}


graph = ...
order = dgl.topological_nodes_generator(graph)  # list of node frontiers
graph.ndata["M"] = torch.zeros(graph.num_nodes(), device="cuda")
startpoints = graph.in_degrees() == 0  # boolean mask of source nodes
endpoints = graph.out_degrees() == 0   # boolean mask of sink nodes

# in a loop (repeated a few thousand times)
graph.ndata["M"].detach_()
# set up some initial values
graph.ndata["M"][startpoints] = ...
graph.edata["W"] = ...

# propagate frontier by frontier, skipping the sources in the first frontier
graph.prop_nodes(order[1:], message_func, reduce_func)

# compute loss and backward
loss = L(graph.ndata["M"][endpoints])
loss.backward()

# do something with the gradient
```
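`prop_nodes` triggers `pull()` on one frontier at a time, so the number of sequential steps equals the number of frontiers, i.e. roughly the depth of the graph. For comparison, here is a rough sketch of the same frontier-by-frontier propagation written directly with PyTorch tensor ops, leaving the gradient w.r.t. the edge weights to autograd. It assumes G(m, w) = m * w and F = sum, so the reduce becomes a scatter-add via `Tensor.index_add_`; the real G and F are custom and may not reduce to this form. The helper names `group_edges_by_frontier` and `propagate` are hypothetical, and the sketch does not use DGL itself.

```python
import torch


def group_edges_by_frontier(dst, frontiers, num_nodes):
    """Precompute (once, since the graph is static) the in-edge IDs of each frontier."""
    level = torch.full((num_nodes,), -1, dtype=torch.long, device=dst.device)
    for i, nodes in enumerate(frontiers):
        level[nodes] = i
    edge_level = level[dst]  # an edge belongs to the frontier of its destination
    perm = torch.argsort(edge_level)
    counts = torch.bincount(edge_level, minlength=len(frontiers))
    offsets = [0] + torch.cumsum(counts, 0).tolist()
    return [perm[offsets[i]:offsets[i + 1]] for i in range(len(frontiers))]


def propagate(src, dst, w, m0, edge_groups):
    """Frontier-by-frontier propagation assuming G(m, w) = m * w and F = sum.

    m0 must hold the initial values at the sources and zeros everywhere else.
    """
    m = m0.clone()
    for eids in edge_groups[1:]:  # frontier 0 holds the sources, like order[1:]
        msg = m[src[eids]] * w[eids]      # G: one message per in-edge of this frontier
        m.index_add_(0, dst[eids], msg)   # F: sum the messages into their destinations
    return m
```

Because the graph is static, `group_edges_by_frontier` would run once, with `frontiers` taken from `dgl.topological_nodes_generator(graph)` and `src, dst` from `graph.edges()`, and only `propagate` would run inside the loop. This sketch has not been benchmarked against `prop_nodes`.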