How to remove nodes from a batched graph

Hi all, I have a batched graph and need to remove some nodes from it. However, after removing some nodes, I found that I can’t unbatch the graphs anymore. Is there a way to solve this, apart from unbatching the graphs every time before I remove the nodes?

Currently you can only unbatch the graphs every time before removing the nodes. Alternatively, you can construct a binary mask for masking out removed nodes during the computation.

Hi, thanks for the reply. I’ve looked into masks, but I can’t seem to find any documentation on them. Can you point me to how to mask the graph in a way that is equivalent to removing the nodes?

You can simply multiply the node features by a binary tensor indicating whether the node is masked.
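In plain PyTorch terms, this is a broadcasted elementwise product. A minimal sketch with made-up feature values, using an (N, 1) mask as in the rest of the thread (1 keeps a node, 0 "removes" it):

```python
import torch

# Node features for 4 nodes (N x D) and a binary mask (N x 1).
x = torch.arange(8, dtype=torch.float32).reshape(4, 2)
mask = torch.tensor([[1.], [0.], [1.], [0.]])

# The (N, 1) mask broadcasts over the feature dimension, zeroing every
# feature of the masked nodes and leaving the others untouched.
x_masked = x * mask
print(x_masked)  # rows: [0, 1], [0, 0], [4, 5], [0, 0]
```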


Hi, just a quick follow-up: how exactly do you do this? For example, if you have a batched graph g and a feature matrix x, can you do something like AvgPool()(g, x*mask), where mask is an Nx1 boolean tensor?

See if the example below makes sense to you.

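```python
import dgl
import torch
from dgl.nn.pytorch import SumPooling

# First graph: 3 nodes, node 0 is masked out (mask = 0).
g1 = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
g1.ndata['h'] = torch.arange(3).reshape(-1, 1).float()
g1.ndata['mask'] = torch.tensor([[0.], [1.], [1.]])
# Second graph: 2 nodes, node 1 is masked out.
g2 = dgl.graph((torch.tensor([0, 1]), torch.tensor([0, 1])))
g2.ndata['h'] = torch.arange(2).reshape(-1, 1).float()
g2.ndata['mask'] = torch.tensor([[1.], [0.]])
bg = dgl.batch([g1, g2])
sumpool = SumPooling()
# Sum of 'h' over the kept nodes only: 'mask' is passed as the readout weight.
h_sum = dgl.sum_nodes(bg, 'h', weight='mask')
# Number of kept nodes in each graph.
num_nodes_exist = sumpool(bg, bg.ndata['mask'])
# Mean of 'h' over the kept nodes of each graph.
g_repr = h_sum / num_nodes_exist
```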

Not sure if I understand this correctly: is ‘mask’ a reserved word, so that DGL automatically ignores masked nodes when sum-pooling ndata[‘h’]? It doesn’t look like we are masking the nodes for ‘h’ anywhere.

DGL’s readout functions, such as dgl.sum_nodes, accept an optional weight argument that multiplies each node’s features by a per-node weight before the readout. See the doc.

I see. This is very helpful, thank you! Two last questions:

  1. Is there a way to mask a node during convolution? If it isn’t masked, it will still send messages to its neighbors, right?
  2. Can something similar be done with GlobalAttentionPooling? I don’t think it has a weight argument.

You can either multiply the node features by a node mask before message passing or multiply the messages by an edge mask during message passing.


You can multiply the node features by the mask before passing them to GlobalAttentionPooling.

I get it now: multiplying a feature by 0 zeroes it out (duh). I was being dumb. Thanks again for such a long post, I really appreciate your help!