Unidirectional Bipartite Graphs (Node Sampling)

I am trying to model a unidirectional bipartite graph (dgl==0.6.1) in which every dst node has its own separate one-hop neighborhood of src nodes. I would like to iterate over the dst nodes with a NodeDataLoader, but I am running into an error.

Here is the graph setup. There are 3 dst nodes, each with its own collection of src nodes pointing to it:

import dgl
import torch
import torch.nn as nn

# Src and dst node IDs (one entry per edge), then create the graph
src = torch.tensor([0, 0, 1, 2, 3, 4, 5, 6, 6])
dst = torch.tensor([0, 0, 0, 1, 2, 2, 2, 2, 2])
g = dgl.heterograph({('src_type', 'edge', 'dst_type'): (src, dst)})

# Generate node features and labels (src and dst nodes have different feature sizes)
g.srcdata["node_feat"] = torch.rand(len(g.srcnodes()), 10)
g.dstdata["node_feat"] = torch.rand(len(g.dstnodes()), 3)
g.dstdata["label"] = torch.randint(0, 2, (len(g.dstnodes()), 1))

When I feed in the full graph, I get the expected result: 3 node embeddings from GraphSAGE, one per dst node:

conv = dgl.nn.SAGEConv((10, 3), 8, aggregator_type="mean")
fc = nn.Linear(8, 1)
h = conv(g, (g.srcdata["node_feat"], g.dstdata["node_feat"]))
print(h)

tensor([[ 0.4699, -0.0037,  0.0341,  1.3000,  0.5906,  0.9941, -0.2148, -0.8874],
        [-0.0616,  0.4049,  0.0663,  0.7090,  0.4139,  1.4098, -0.3871, -0.3437],
        [ 0.2171, -0.6719,  0.5209,  1.3910,  0.9904,  1.0238,  0.6527, -0.3474]],
       grad_fn=<AddBackward0>)
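
For reference, the neighbor-aggregation step behind those three rows can be sketched in plain Python. This is only an illustration of mean aggregation over incoming edges, not DGL's implementation: it uses a hypothetical scalar feature per src node (the node's own ID) in place of the random 10-dimensional features, and it leaves out the learned linear layers that SAGEConv applies to the dst feature and to the neighbor mean.

```python
from collections import defaultdict

# Edge list copied from the graph setup above (one entry per edge).
src = [0, 0, 1, 2, 3, 4, 5, 6, 6]
dst = [0, 0, 0, 1, 2, 2, 2, 2, 2]

# Hypothetical scalar feature per src node (its own ID), chosen for readability.
src_feat = {s: float(s) for s in set(src)}

# Collect one message per incoming edge for every dst node.
incoming = defaultdict(list)
for s, d in zip(src, dst):
    incoming[d].append(src_feat[s])

# Mean over the incoming messages: one aggregated value per dst node.
neigh_mean = {d: sum(msgs) / len(msgs) for d, msgs in sorted(incoming.items())}
print(neigh_mean)
```

Each dst node ends up with exactly one aggregated value, and the parallel edges (0 -> 0 and 6 -> 2 each appear twice) count twice in their means because every edge carries its own message.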

However, when I try to sample the full neighborhoods with a NodeDataLoader, the SAGEConv call fails with an AssertionError:

sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dl = dgl.dataloading.NodeDataLoader(g, {"dst_type": torch.tensor([0, 1, 2])}, block_sampler=sampler, batch_size=2)
for input_nodes, output_nodes, blocks in dl:
    print(input_nodes)
    print(output_nodes)
    print(blocks)
    h = conv(blocks[0], (blocks[0].srcdata["node_feat"]["src_type"], blocks[0].dstdata["node_feat"]["dst_type"]))
    print(h)

{'dst_type': tensor([0, 1]), 'src_type': tensor([0, 1, 2])}
{'dst_type': tensor([0, 1]), 'src_type': tensor([], dtype=torch.int64)}
[Block(num_src_nodes={'dst_type': 2, 'src_type': 3},
      num_dst_nodes={'dst_type': 2, 'src_type': 0},
      num_edges={('src_type', 'edge', 'dst_type'): 4},
      metagraph=[('src_type', 'dst_type', 'edge')])]
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
/var/folders/28/14kmqy7j34x7cybycs_z3wq80000gp/T/ipykernel_49432/60589171.py in <module>
      3     print(output_nodes)
      4     print(blocks)
----> 5     conv(blocks[0], (blocks[0].srcdata["node_feat"]["src_type"], blocks[0].dstdata["node_feat"]["dst_type"]))
      6     logits = fc(h)
      7     loss = F.binary_cross_entropy_with_logits(logits, g.dstdata["fpf"].float())

~/opt/miniconda3/envs/torch-graph/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~/opt/miniconda3/envs/torch-graph/lib/python3.8/site-packages/dgl/nn/pytorch/conv/sageconv.py in forward(self, graph, feat, edge_weight)
    228             # Message Passing
    229             if self._aggre_type == 'mean':
--> 230                 graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
    231                 graph.update_all(msg_fn, fn.mean('m', 'neigh'))
    232                 h_neigh = graph.dstdata['neigh']

~/opt/miniconda3/envs/torch-graph/lib/python3.8/site-packages/dgl/view.py in __setitem__(self, key, val)
     68     def __setitem__(self, key, val):
     69         if isinstance(self._ntype, list):
---> 70             assert isinstance(val, dict), \
     71                 'Current HeteroNodeDataView has multiple node types, ' \
     72                 'please passing the node type and the corresponding data through a dict.'

AssertionError: Current HeteroNodeDataView has multiple node types, please passing the node type and the corresponding data through a dict.

Part of the issue seems to be that the block now lists src_type among its dst node types, even though it has zero src_type nodes on the dst side:

print(blocks[0].dsttypes)
print(g.dsttypes)

['dst_type', 'src_type']
['dst_type']

Have you tried dgl.in_subgraph instead? For example:

sg = dgl.in_subgraph(g, {'dst_type': 0}, relabel_nodes=True)