Recently, I tried to implement mini-batch training on a heterograph with the HGTConv module. However, I ran into an error that I cannot understand.
Below is the error:
```
Traceback (most recent call last):
  File "/root/hhs/gene_pretrain/graph_pretrain/graph_model/HGT/GraphNodeClassification.py", line 441, in <module>
    train_acc, val_acc, test_acc, precision, recall, f1 = train(graph, args)
                                                          ^^^^^^^^^^^^^^^^^^
  File "/root/hhs/gene_pretrain/graph_pretrain/graph_model/HGT/GraphNodeClassification.py", line 249, in train
    h_blocks = model[0](blocks)
               ^^^^^^^^^^^^^^^^
  File "/data/anaconda3/envs/ocean/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/hhs/gene_pretrain/graph_pretrain/graph_model/HGT/HGT.py", line 300, in forward
    h = layer(block, h_src, srctypes, etype, presorted=False)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/anaconda3/envs/ocean/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/anaconda3/envs/ocean/lib/python3.11/site-packages/dgl/nn/pytorch/conv/hgtconv.py", line 165, in forward
    g.srcdata["k"] = k
    ~~~~~~~~~^^^^^
  File "/data/anaconda3/envs/ocean/lib/python3.11/site-packages/dgl/view.py", line 86, in __setitem__
    assert isinstance(val, dict), (
AssertionError: Current HeteroNodeDataView has multiple node types, please passing the node type and the corresponding data through a dict.
```
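From the assertion, DGL's `HeteroNodeDataView` refuses a bare tensor assignment like `g.srcdata["k"] = k` whenever the graph (here, the sampled block) still has more than one source node type; it expects a dict mapping each node type to its data. Here is a rough pure-Python analogy of that check, just to illustrate the condition that fires. This is not the actual DGL source:

```python
class NodeDataViewAnalogy:
    """Rough analogy of DGL's HeteroNodeDataView assignment check (not DGL code)."""

    def __init__(self, ntypes):
        self.ntypes = list(ntypes)  # node types present in the (block) graph
        self.store = {}

    def __setitem__(self, key, val):
        if len(self.ntypes) > 1:
            # With multiple node types, a bare value is ambiguous: there is no
            # way to tell which type's nodes the rows belong to.
            assert isinstance(val, dict), (
                "multiple node types: pass {ntype: data} instead of a bare value"
            )
            self.store[key] = val
        else:
            # A single node type is unambiguous, so a bare value is accepted.
            self.store[key] = {self.ntypes[0]: val}


view = NodeDataViewAnalogy(["gene", "disease"])
view["k"] = {"gene": [0.1, 0.2], "disease": [0.3]}  # dict form: accepted

failed = False
try:
    view["k"] = [0.1, 0.2, 0.3]  # bare value with two node types -> AssertionError
except AssertionError:
    failed = True
```

So the block reaching HGTConv still reports multiple source node types, which is why the homogeneous-style assignment inside the layer trips this check.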
Below is my implementation code:
```python
def forward(self, blocks):
    """
    The forward part of the HGT.

    Parameters
    ----------
    blocks : list of block for mini-batch training
        list of blocks, acquired through a DGL sampler.

    Returns
    -------
    dict
        The embeddings after the output projection.
    """
    for layer, block in zip(self.hgt_layers, blocks):
        # concat different node type embeddings
        h_src = []
        for srctype in block.srctypes:
            h_src.append(block.srcdata['embedding'][srctype].squeeze(1))
        h_src = torch.cat(h_src)
        # get hgtconv required ntypes and etypes
        srctypes = get_block_srctypes(block).to(self.device)
        etype = get_block_etype(block, self.device).to(self.device)
        # into hgtconv
        # layer forward signature: layer(g, x, ntype, etype, *, presorted=False)
        h = layer(block, h_src, srctypes, etype, presorted=False)  # <- this is where the error occurs
        # to_hetero_feat: h, type, name
        h_dict = to_hetero_feat(h, get_block_dsttypes(block), self.ntypes)
    return h_dict
```
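For clarity on the last step: as I understand it, `to_hetero_feat` (an OpenHGNN-style utility) takes the stacked output feature matrix plus a per-node type-id tensor and splits it back into a per-type dict. A simplified pure-Python sketch of that idea, which is an assumption about its behavior rather than the actual implementation:

```python
def to_hetero_feat_sketch(h, type_ids, names):
    """Simplified sketch of a to_hetero_feat-style split (assumption, not the
    OpenHGNN implementation): rows of h whose type id is i go under names[i]."""
    out = {}
    for i, name in enumerate(names):
        rows = [row for row, t in zip(h, type_ids) if t == i]
        if rows:
            out[name] = rows
    return out


h = [[1.0], [2.0], [3.0]]
type_ids = [0, 1, 0]  # rows 0 and 2 are type 0, row 1 is type 1
feats = to_hetero_feat_sketch(h, type_ids, ["gene", "disease"])
```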
where `get_block_srctypes` and `get_block_etype` receive the current `block` object and return the 1D `ntype` and `etype` tensors required by the HGTConv layer.
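For context, a helper like `get_block_srctypes` presumably expands the block's per-type source-node counts into one type id per node, in the same order the features were concatenated. A hypothetical pure-Python sketch (the name and layout here are assumptions, not my actual helper):

```python
def srctypes_tensor_sketch(counts_per_type):
    """Hypothetical sketch: given the number of source nodes of each type
    (in concatenation order), emit one type-id entry per node."""
    type_ids = []
    for type_id, count in enumerate(counts_per_type):
        type_ids.extend([type_id] * count)
    return type_ids


# e.g. 3 nodes of type 0 followed by 2 nodes of type 1
ids = srctypes_tensor_sketch([3, 2])
```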
Thanks for your time in advance!