Learning edges between nodes

Is there a differentiable way to add or remove edges between nodes?
For some background, if it helps: I have a task where I am classifying entire graphs. I have trained a useful model for this, but the actual edges in the graphs cannot be entirely trusted, and I’ve found that changing the edges of a particular graph can improve the model’s prediction on it. One could imagine training a second model that outputs edge-existence probabilities, using those probabilities to set the edges of the input graphs fed to the original model, and evaluating a loss on the original model’s output. The issue I’ve run into is that I need a way to differentiably add or remove edges based on these probabilities, so that backpropagation can flow through the edge-selection model.


You can do something as follows.

import dgl.function as fn
import torch

# Assume g is a DGLGraph and g.ndata['h'] stores input node features
g.edata['w'] = torch.ones(g.num_edges(), requires_grad=True)
g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))
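In case it helps to see the mechanics without DGL, here is a plain-Python sketch (hypothetical helper name, scalar node features for brevity) of what u_mul_e followed by sum computes: each message is the source node’s feature scaled by the edge weight, and each destination node sums its incoming messages.

```python
def weighted_message_passing(edges, h, w):
    """Compute h_new[v] = sum over edges (u, v) of w[e] * h[u].

    edges: list of (source, destination) pairs
    h: dict mapping node id -> scalar feature
    w: list of per-edge weights, aligned with `edges`
    """
    h_new = {v: 0.0 for v in h}
    for e, (u, v) in enumerate(edges):
        h_new[v] += w[e] * h[u]
    return h_new

edges = [(0, 1), (1, 2), (2, 0)]
h = {0: 1.0, 1: 2.0, 2: 3.0}

# All weights 1: ordinary message passing over the original graph.
print(weighted_message_passing(edges, h, [1.0, 1.0, 1.0]))  # {0: 3.0, 1: 1.0, 2: 2.0}

# Zeroing an edge's weight is equivalent to deleting that edge.
print(weighted_message_passing(edges, h, [1.0, 1.0, 0.0]))  # {0: 0.0, 1: 1.0, 2: 2.0}
```

Since the weights enter the result multiplicatively, gradients of any loss on h_new flow back to them, which is what makes the edge mask learnable.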

Hey Mufei, thanks for the reply. Sorry, I should have included a simple example because I am not very good at DGL and so got lost immediately.

Suppose we have a graph defined as:

import dgl
import torch

u = [0, 1, 2, 3, 1, 2, 3, 0]
v = [1, 2, 3, 0, 0, 1, 2, 3]
g = dgl.graph((u, v))
node_feats = torch.tensor([[1., 0.], [0., 1.], [1., 0.], [0., 1.]], requires_grad=True)
edge_feats = torch.tensor([[1., 0.], [0., 1.], [1., 0.], [0., 1.],
                           [1., 0.], [0., 1.], [1., 0.], [0., 1.]], requires_grad=True)
g.ndata['hv'] = node_feats
g.edata['he'] = edge_feats

Which would have an adjacency matrix like this:

[[0., 1., 0., 1.],
 [1., 0., 1., 0.],
 [0., 1., 0., 1.],
 [1., 0., 1., 0.]]
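As a sanity check, that matrix can be recovered from the u/v lists above, using the convention that rows index source nodes and columns index destination nodes (a plain-Python sketch, no DGL required):

```python
u = [0, 1, 2, 3, 1, 2, 3, 0]
v = [1, 2, 3, 0, 0, 1, 2, 3]

n = 4  # number of nodes in the example graph
adj = [[0.0] * n for _ in range(n)]
for src, dst in zip(u, v):
    adj[src][dst] = 1.0  # row = source node, column = destination node

# adj now equals the 4x4 matrix shown above
```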

Now I have some model that has decided that the graph’s new adjacency matrix should be:

[[0., 1., 0., 0.],
 [1., 0., 1., 0.],
 [0., 1., 0., 1.],
 [0., 0., 1., 0.]]

E.g., removing the edge between node 0 and node 3, and deleting the edge features associated with it. In general, though, the model may also want to create new edges (for my specific problem, all edges created or deleted this way have the same features, but there are other edges, which aren’t touched, that have different edge features).

Now let me try to understand your code:

g.edata['w'] = torch.ones(g.num_edges(), requires_grad=True)

Here you are generating edge features.

g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))

fn.u_mul_e('h', 'w', 'm') generates messages by multiplying the source node features with the edge weights, fn.sum('m', 'h_new') sums these messages together to get new node features, and g.update_all() updates the node features to the newly computed ones.
It’s not clear to me how this can be used to add or remove edges. Sorry for not posting an example earlier; could you show how to update the edges from the old adjacency matrix to the new one I plotted above, so that I can follow your solution more clearly?
Thank you for your time.

g.edata['w'] = torch.ones(g.num_edges(), requires_grad=True)

generates learnable weights/masks over existing edges. With

g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))

you multiply the source node features by the edge weights during message passing. If an edge weight is 1, this is the same as message passing over the original graph. If an edge weight decays to 0, the edge is effectively “deleted”. You may want to apply a sigmoid to keep the edge weights in (0, 1), though.
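For instance, a minimal plain-Python sketch of the sigmoid parameterization (the logit values here are hypothetical, standing in for what a model might learn): keep one unconstrained logit per edge and squash it, so the mask stays in (0, 1) and gradient descent can push any edge toward “kept” or “deleted”.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

# One free (learnable) logit per edge; sigmoid keeps each weight in (0, 1).
logits = [4.0, -6.0, 0.5]              # hypothetical learned values
weights = [sigmoid(z) for z in logits]

# A strongly negative logit drives its weight toward 0, effectively deleting
# that edge from message passing; a strongly positive one keeps the edge.
```

In DGL terms, these weights would be stored as g.edata['w'] and consumed by the u_mul_e message function shown above.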

Thanks for the reply! I’ll try that out soon, but for now, how would you extend that to adding edges?

One possibility is to create a complete graph where all pairs of nodes are connected. Depending on the particular dataset, this may or may not be too costly. Another possibility is to randomly add new edges and keep the ones that receive high weights. The latter might require some algorithmic innovation.
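A minimal sketch of the complete-graph idea, assuming the 4-node example graph from earlier in the thread (the initialization constants are arbitrary choices, not a prescribed scheme): enumerate every ordered pair of distinct nodes as a candidate edge and give each one a learnable logit, initialized so the original edges start “on” and the absent ones start “off”. Training can then raise a logit to add an edge or lower it to remove one.

```python
n = 4
# Edges of the example graph, as (source, destination) pairs.
existing = {(0, 1), (1, 2), (2, 3), (3, 0), (1, 0), (2, 1), (3, 2), (0, 3)}

# Candidate edge set: every ordered pair of distinct nodes.
candidates = [(u, v) for u in range(n) for v in range(n) if u != v]

# Initialize logits so existing edges start "on" (weight near 1 after sigmoid)
# and absent edges start "off" (weight near 0 after sigmoid).
init_logits = [5.0 if e in existing else -5.0 for e in candidates]

print(len(candidates))  # 12 candidate edges for 4 nodes
```

The cost is the O(n^2) candidate set, which is why this works for small graphs but may be prohibitive for large ones.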
