I am trying to model a unidirectional bipartite graph (dgl==0.6.1) where there are basically one hop neighborhoods all separated. I would like to iterate through the dst nodes with a node data loader but I am running into some issues.
Here is the graph setup:
We have 3 dst nodes that each have a collection of src nodes pointing to them.
# Src and dst nodes & create graph
src = torch.tensor([0, 0, 1, 2, 3, 4, 5, 6, 6])
dst = torch.tensor([0,0,0, 1, 2, 2, 2, 2, 2])
g = dgl.heterograph({('src_type', 'edge', 'dst_type') : (src, dst)})
# Generate node features and labels (dst_nodes have different number of features)
g.srcdata["node_feat"] = torch.rand(len(g.srcnodes()), 10)
g.dstdata["node_feat"] = torch.rand(len(g.dstnodes()), 3)
g.dstdata["label"] = torch.randint(0, 2, (len(g.dstnodes()),1))
When I feed in the full graph I get the expected result which is 3 node embeddings from graph sage:
conv = dgl.nn.SAGEConv((10 , 3), 8, aggregator_type = "mean")
fc = nn.Linear(8, 1)
h= conv(g, (g.srcdata["node_feat"], g.dstdata["node_feat"]))
print(h)
tensor([[ 0.4699, -0.0037, 0.0341, 1.3000, 0.5906, 0.9941, -0.2148, -0.8874],
[-0.0616, 0.4049, 0.0663, 0.7090, 0.4139, 1.4098, -0.3871, -0.3437],
[ 0.2171, -0.6719, 0.5209, 1.3910, 0.9904, 1.0238, 0.6527, -0.3474]],
grad_fn=<AddBackward0>)
However when I try to sample the full neighborhoods, I run into some error:
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dl = dgl.dataloading.NodeDataLoader(g, {"dst_type": torch.tensor([0,1,2])}, block_sampler=sampler, batch_size = 2)
for input_nodes, output_nodes, blocks in dl:
print(input_nodes)
print(output_nodes)
print(blocks)
h = conv(blocks[0], (blocks[0].srcdata["node_feat"]["src_type"], blocks[0].dstdata["node_feat"]["dst_type"]))
print(h)
{'dst_type': tensor([0, 1]), 'src_type': tensor([0, 1, 2])}
{'dst_type': tensor([0, 1]), 'src_type': tensor([], dtype=torch.int64)}
[Block(num_src_nodes={'dst_type': 2, 'src_type': 3},
num_dst_nodes={'dst_type': 2, 'src_type': 0},
num_edges={('src_type', 'edge', 'dst_type'): 4},
metagraph=[('src_type', 'dst_type', 'edge')])]
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
/var/folders/28/14kmqy7j34x7cybycs_z3wq80000gp/T/ipykernel_49432/60589171.py in <module>
3 print(output_nodes)
4 print(blocks)
----> 5 conv(blocks[0], (blocks[0].srcdata["node_feat"]["src_type"], blocks[0].dstdata["node_feat"]["dst_type"]))
6 logits = fc(h)
7 loss = F.binary_cross_entropy_with_logits(logits, g.dstdata["fpf"].float())
~/opt/miniconda3/envs/torch-graph/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
887 result = self._slow_forward(*input, **kwargs)
888 else:
--> 889 result = self.forward(*input, **kwargs)
890 for hook in itertools.chain(
891 _global_forward_hooks.values(),
~/opt/miniconda3/envs/torch-graph/lib/python3.8/site-packages/dgl/nn/pytorch/conv/sageconv.py in forward(self, graph, feat, edge_weight)
228 # Message Passing
229 if self._aggre_type == 'mean':
--> 230 graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
231 graph.update_all(msg_fn, fn.mean('m', 'neigh'))
232 h_neigh = graph.dstdata['neigh']
~/opt/miniconda3/envs/torch-graph/lib/python3.8/site-packages/dgl/view.py in __setitem__(self, key, val)
68 def __setitem__(self, key, val):
69 if isinstance(self._ntype, list):
---> 70 assert isinstance(val, dict), \
71 'Current HeteroNodeDataView has multiple node types, ' \
72 'please passing the node type and the corresponding data through a dict.'
AssertionError: Current HeteroNodeDataView has multiple node types, please passing the node type and the corresponding data through a dict.
It looks like part of the issue is the block has registered src_type
as a dsttype now:
print(blocks[0].dsttypes)
print(g.dsttypes)
['dst_type', 'src_type']
['dst_type']