Feed Graph to the User-defined Aggregation Function


Thank you for developing this wonderful tool. I would like to implement a user-defined aggregation function with learnable parameters. Now I need to feed the input graph to my user-defined message function and reduce function as an input parameter, something like graph.update_all(self.message_func, self.reduce_func(graph)). Is it possible to achieve this goal with DGL? I would really appreciate the help. Thank you.


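You can do something like the following: wrap the UDF in a function that takes the graph, so the inner function can refer to it through a closure.

def custom_udf(graph):
    def msg_func(edges):
        # `graph` is captured from the enclosing scope and can be used here,
        # e.g. graph.num_nodes(); 'h' and 'm' are placeholder feature names
        return {'m': edges.src['h']}
    return msg_func

# usage: graph.update_all(custom_udf(graph), your_reduce_func)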

Also, have you checked the use of NodeBatch and EdgeBatch here?


Hi @mufeili

Thank you for the quick response! I appreciate it. I think I got the point after reading your response and checking the use of NodeBatch and EdgeBatch. Thank you!