Differentiable propagation


I have a DAG where each edge has a trainable weight. I need to propagate a message in topological order using differentiable aggregation/reduce functions, then compute a loss at the output and get its gradient with respect to the edge weights. Basically, it's like the forward/backward pass of a neural network. The graph structure is static but very large (a few hundred million nodes and edges), and I need to run this operation many times (a few thousand).
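To make the forward/backward analogy concrete, here is a minimal pure-Python sketch of one pass on a DAG. It assumes, for illustration only, that G(m, w) = w·m, that F sums the incoming messages, and that the loss is the sum of the endpoint messages; none of these choices come from the post, where G, F and L are custom. The forward pass visits nodes in topological order; the backward pass visits them in reverse order and accumulates adjoints, which is the bookkeeping autograd performs in the PyTorch version.

```python
from collections import defaultdict

def topo_order(num_nodes, edges):
    """Kahn's algorithm: return the node IDs of a DAG in topological order."""
    indeg = [0] * num_nodes
    succ = defaultdict(list)
    for u, v in edges:
        indeg[v] += 1
        succ[u].append(v)
    order = [n for n in range(num_nodes) if indeg[n] == 0]
    for u in order:  # `order` grows while we iterate over it
        for v in succ[u]:
            indeg[v] -= 1
            if indeg[v] == 0:
                order.append(v)
    return order

def forward_backward(num_nodes, edges, weights, init):
    """One forward/backward pass, ASSUMING G(m, w) = w * m, F = sum, L = sum over sinks.

    Returns (loss, grad_w) where grad_w[e] = dloss/dweights[e] for edges[e].
    """
    order = topo_order(num_nodes, edges)
    incoming = defaultdict(list)  # node -> IDs of its incoming edges
    has_out = [False] * num_nodes
    for e, (u, v) in enumerate(edges):
        incoming[v].append(e)
        has_out[u] = True

    # Forward: start nodes keep their initial value, every other node
    # sums the weighted messages of its predecessors.
    m = list(init)
    for v in order:
        if incoming[v]:
            m[v] = sum(weights[e] * m[edges[e][0]] for e in incoming[v])
    sinks = [n for n in range(num_nodes) if not has_out[n]]
    loss = sum(m[n] for n in sinks)

    # Backward: adj[n] = dloss/dm[n]. Reverse topological order guarantees
    # adj[v] is complete before it is pushed back to v's inputs.
    adj = [0.0] * num_nodes
    for n in sinks:
        adj[n] = 1.0
    grad_w = [0.0] * len(edges)
    for v in reversed(order):
        for e in incoming[v]:
            u = edges[e][0]
            grad_w[e] = adj[v] * m[u]      # d(w_e * m_u) / dw_e = m_u
            adj[u] += adj[v] * weights[e]  # d(w_e * m_u) / dm_u = w_e
    return loss, grad_w

# Nodes 0 and 1 feed node 2, which feeds the single endpoint 3.
loss, grad = forward_backward(4, [(0, 2), (1, 2), (2, 3)], [2.0, 3.0, 0.5], [1.0, 4.0, 0.0, 0.0])
print(loss, grad)  # 7.0 [0.5, 2.0, 14.0]
```

Note that the backward sweep has the same sequential structure as the forward one, so any per-step overhead in the forward pass is paid again during backpropagation.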

I have tried implementing the propagation with UDF (user-defined) message and reduce functions, but it's pretty slow on deep graphs:

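```python
import dgl
import torch

def message_func(edges):
    # G is a custom differentiable function of the source message and the edge weight
    return {"M": G(edges.src["M"], edges.data["W"])}

def reduce_func(nodes):
    # F is a custom differentiable function over each node's incoming messages
    return {"M": F(nodes.mailbox["M"])}

graph = ...  # the static DAG, on the GPU
order = dgl.topological_nodes_generator(graph)  # one node-ID tensor per frontier
# keep the trainable weights as a separate leaf tensor so that W.grad gets populated
W = torch.zeros(graph.num_edges(), requires_grad=True, device="cuda")
startpoints = graph.in_degrees() == 0
endpoints = graph.out_degrees() == 0

# in a loop
W.grad = None
# fresh node buffer each iteration, so backward never reaches the previous iteration's graph
M = torch.zeros(graph.num_nodes(), device="cuda")
with torch.no_grad():
    # set up some initial values; W is updated in place so it stays the same leaf tensor
    M[startpoints] = ...
    W.copy_(...)
graph.ndata["M"] = M
graph.edata["W"] = W

# propagate frontier by frontier, skipping the first frontier (the start nodes)
graph.prop_nodes(order[1:], message_func, reduce_func)

# compute loss and backward
loss = L(graph.ndata["M"][endpoints])
loss.backward()

# do something with the gradient in W.grad
```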

Do you have any advice on speeding this up? Is it possible to write our own CUDA kernels for the message/reduce functions?


I think the critical point here is that you need to propagate messages in topological order. If that is mandatory, the computation time grows at least linearly with the depth of the graph, because the dependencies force you to compute the frontiers sequentially, one layer at a time. What is the depth of your static graph?
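The dependency argument can be illustrated with a small pure-Python sketch (not DGL's implementation). It splits a DAG into frontiers in the spirit of a layer-by-layer traversal, then runs the forward pass with one step per frontier, skipping the first frontier as `order[1:]` does above. Only the outer loop is sequential, so the number of steps equals the depth; the edges within a step are independent and could, in principle, be handled by one batched gather plus scatter-add on the GPU. The aggregation G(m, w) = w·m with F = sum is an assumption for illustration; the thread's G and F are custom.

```python
from collections import defaultdict, deque

def frontiers(num_nodes, src, dst):
    """Split a DAG's nodes into topological frontiers (layers).

    A node's frontier index is the length of the longest path reaching it
    from a start node (in-degree 0), so every edge points from a lower
    frontier to a strictly higher one.
    """
    indeg = [0] * num_nodes
    succ = defaultdict(list)
    for u, v in zip(src, dst):
        indeg[v] += 1
        succ[u].append(v)
    level = [0] * num_nodes
    queue = deque(n for n in range(num_nodes) if indeg[n] == 0)
    while queue:
        u = queue.popleft()
        for v in succ[u]:
            level[v] = max(level[v], level[u] + 1)
            indeg[v] -= 1
            if indeg[v] == 0:  # all predecessors done, so level[v] is final
                queue.append(v)
    layers = defaultdict(list)
    for n, k in enumerate(level):
        layers[k].append(n)
    return [layers[k] for k in range(len(layers))]

def propagate_by_frontier(num_nodes, src, dst, weights, init):
    """Forward pass with one batched step per frontier (ASSUMES G = w * m, F = sum)."""
    layers = frontiers(num_nodes, src, dst)
    level = [0] * num_nodes
    for k, layer in enumerate(layers):
        for n in layer:
            level[n] = k
    edge_groups = defaultdict(list)  # frontier k -> IDs of edges ending in frontier k
    for e, v in enumerate(dst):
        edge_groups[level[v]].append(e)

    m = list(init)
    for k in range(1, len(layers)):  # the only sequential loop: depth iterations
        # Every edge in this group reads from an earlier frontier, so the whole
        # group is independent of itself and can be processed as one batch.
        acc = defaultdict(float)
        for e in edge_groups[k]:
            acc[dst[e]] += weights[e] * m[src[e]]
        for v, total in acc.items():
            m[v] = total
    return m

src, dst = [0, 1, 2], [2, 2, 3]
print(len(frontiers(4, src, dst)) - 1)  # depth: 2 sequential steps
print(propagate_by_frontier(4, src, dst, [2.0, 3.0, 0.5], [1.0, 4.0, 0.0, 0.0]))
# [1.0, 4.0, 14.0, 7.0]
```

With this view, the cost of one propagation is roughly (depth) × (per-step overhead) + (total edge work), so for a graph with hundreds of millions of edges but a long longest path, the per-step overhead is what dominates, and it is paid again in the backward pass.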
