How can I preserve some node types when using MultiLayerFullNeighborSampler?

I am working with a heterogeneous graph network. When I use MultiLayerFullNeighborSampler to do mini-batch training on node type ‘A’, I would like the last block to preserve the nodes of type ‘B’ that are linked to the type ‘A’ nodes. What should I do? Many thanks!

Do you mean you want to include all B nodes that are connected to the sampled A nodes? If so, you will need to write a new dataloader.

The fastest way to implement this is to subclass dgl.dataloading.EdgeCollator, since it already samples the MFGs (blocks) given a set of edges. All you need to do is override the collate function so that it takes node IDs instead of edge IDs, finds the edges going from A to B yourself, and then passes those edges to EdgeCollator.collate. Something like:

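import dgl

class YourCollator(dgl.dataloading.EdgeCollator):
    def collate(self, A_nodes):
        # IDs of the 'AB' edges leaving the sampled A nodes (etype needed on a heterograph)
        AB_edges = self.g.out_edges(A_nodes, form='eid', etype='AB').tolist()
        # A heterograph EdgeCollator expects (edge type, edge ID) pairs
        return super().collate([('AB', i) for i in AB_edges])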

Then you can use a torch.utils.data.DataLoader that iterates over your type-A node IDs and pass YourCollator.collate as its collate_fn.
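To make the node-to-edge step concrete without needing DGL installed, here is a minimal pure-Python sketch. It assumes a toy graph whose 'AB' edges are stored as two parallel lists of source (A) and destination (B) IDs, with the edge ID being the list position. The helpers out_edge_ids, node_batch_to_edge_items and batches are hypothetical stand-ins for g.out_edges(..., form='eid'), the overridden collate, and the DataLoader's batching. The sketch only shows which (edge type, edge ID) pairs would be handed to EdgeCollator.collate and which B nodes those edges reach; it does not do any block sampling.

```python
# Pure-Python sketch of the node-ID -> (edge type, edge ID) conversion that
# the overridden collate performs before calling EdgeCollator.collate.
# The toy graph and all helper names are hypothetical, for illustration only.

# Toy 'AB' edges as parallel lists; an edge's ID is its position in the lists.
AB_SRC = [0, 0, 1, 2, 2, 2]  # source A node of each edge
AB_DST = [0, 1, 1, 0, 2, 3]  # destination B node of each edge


def out_edge_ids(a_nodes, src=AB_SRC):
    """IDs of edges whose source is in a_nodes (mimics out_edges(..., form='eid'))."""
    wanted = set(a_nodes)
    return [eid for eid, s in enumerate(src) if s in wanted]


def node_batch_to_edge_items(a_nodes, etype='AB'):
    """Turn a batch of A node IDs into (edge type, edge ID) pairs."""
    return [(etype, eid) for eid in out_edge_ids(a_nodes)]


def batches(ids, batch_size):
    """Yield consecutive batches, as a non-shuffling DataLoader would."""
    for i in range(0, len(ids), batch_size):
        yield ids[i:i + batch_size]


if __name__ == '__main__':
    for batch in batches([0, 1, 2], batch_size=2):
        items = node_batch_to_edge_items(batch)
        # B nodes reached by this batch's edges
        reached_b = sorted({AB_DST[eid] for _, eid in items})
        print(batch, items, reached_b)
```

Every B node adjacent to a batch's A nodes appears as the destination of one of the collected edges, which is why routing the batch through EdgeCollator brings those B nodes into the sampled blocks.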
