In a heterogeneous graph, when I use MultiLayerFullNeighborSampler for mini-batch training on node type 'A', I would like the last block to also keep the type-'B' nodes that are linked to the type-'A' nodes. What should I do? Many thanks!
Do you mean you want to include all B nodes that are connected to the sampled A nodes? If so, you will need to write a new dataloader.
The fastest way to implement this is to subclass dgl.dataloading.EdgeCollator, since it already samples the MFGs (blocks) for a given set of edges. All you need to do is change the collate function so that it takes node IDs instead of edge IDs, find the edges going from A to B yourself, and pass those edges to EdgeCollator.collate. Something like:
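import dgl

class YourCollator(dgl.dataloading.EdgeCollator):
    def collate(self, A_nodes):
        # IDs of the outgoing 'AB' edges of this batch of A nodes
        # (the graph has several edge types, so etype must be given)
        AB_edges = self.g.out_edges(A_nodes, form='eid', etype='AB').numpy()
        # On a heterogeneous graph, EdgeCollator expects (etype, edge ID) pairs
        return super().collate([('AB', i) for i in AB_edges])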
Then you can use a torch.utils.data.DataLoader that iterates over your A node IDs, passing the collate method of a YourCollator instance as its collate_fn.
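To make the node-to-edge conversion concrete without needing DGL installed, here is a plain-Python sketch of what the overridden collate hands to EdgeCollator.collate. It assumes a hypothetical toy 'AB' edge list stored as parallel source/destination lists (ab_src, ab_dst); the helper names are made up and only mimic g.out_edges(..., form='eid', etype='AB'), whose output order may differ.

```python
from typing import List, Sequence, Tuple

# Hypothetical toy graph: 'AB' edge i goes from A node ab_src[i] to B node ab_dst[i].
ab_src = [0, 0, 1, 2, 3]
ab_dst = [0, 1, 1, 2, 0]


def node_batch_to_edge_items(
    a_nodes: Sequence[int], src: Sequence[int], etype: str = "AB"
) -> List[Tuple[str, int]]:
    """(etype, edge ID) pairs for every edge whose source is in a_nodes.

    Mimics g.out_edges(a_nodes, form='eid', etype=etype) on a toy edge list.
    """
    wanted = set(a_nodes)
    return [(etype, eid) for eid, s in enumerate(src) if s in wanted]


def reached_b_nodes(items: List[Tuple[str, int]], dst: Sequence[int]) -> List[int]:
    """Sorted, de-duplicated B node IDs at the far end of the selected edges."""
    return sorted({dst[eid] for _, eid in items})


items = node_batch_to_edge_items([0, 2], ab_src)
print(items)                                  # [('AB', 0), ('AB', 1), ('AB', 3)]
print(reached_b_nodes(items, ab_dst))         # [0, 1, 2]
print(node_batch_to_edge_items([4], ab_src))  # []
```

Note the last call: a hypothetical A node 4 with no outgoing AB edge produces no items, so with this approach it would not appear in that batch at all.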