Error when using HGTConv

Recently, I tried to implement mini-batch training on a heterograph with the HGTConv module, but I ran into an error that I cannot make sense of.
Below is the error:

Traceback (most recent call last):
  File "/root/hhs/gene_pretrain/graph_pretrain/graph_model/HGT/GraphNodeClassification.py", line 441, in <module>
    train_acc, val_acc, test_acc, precision, recall, f1 = train(graph, args)
                                                          ^^^^^^^^^^^^^^^^^^
  File "/root/hhs/gene_pretrain/graph_pretrain/graph_model/HGT/GraphNodeClassification.py", line 249, in train
    h_blocks = model[0](blocks)
               ^^^^^^^^^^^^^^^^
  File "/data/anaconda3/envs/ocean/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/hhs/gene_pretrain/graph_pretrain/graph_model/HGT/HGT.py", line 300, in forward
    h = layer(block, h_src, srctypes, etype, presorted=False)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/anaconda3/envs/ocean/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/anaconda3/envs/ocean/lib/python3.11/site-packages/dgl/nn/pytorch/conv/hgtconv.py", line 165, in forward
    g.srcdata["k"] = k
    ~~~~~~~~~^^^^^
  File "/data/anaconda3/envs/ocean/lib/python3.11/site-packages/dgl/view.py", line 86, in __setitem__
    assert isinstance(val, dict), (
AssertionError: Current HeteroNodeDataView has multiple node types, please passing the node type and the corresponding data through a dict.

Below is my implementation:

def forward(self, blocks):
    """
    The forward pass of the HGT model.

    Parameters
    ----------
    blocks : list of DGLBlock
        Message flow graphs (blocks) produced by a DGL sampler, one per layer.

    Returns
    -------
    dict
        The embeddings after the output projection.
    """
    for layer, block in zip(self.hgt_layers, blocks):

        # concatenate the embeddings of all source node types
        h_src = []
        for srctype in block.srctypes:
            h_src.append(block.srcdata['embedding'][srctype].squeeze(1))
        h_src = torch.cat(h_src)

        # build the ntype and etype tensors that HGTConv requires
        srctypes = get_block_srctypes(block).to(self.device)
        etype = get_block_etype(block, self.device).to(self.device)

        # pass into HGTConv; its forward signature is
        # layer(g, x, ntype, etype, *, presorted=False)
        h = layer(block, h_src, srctypes, etype, presorted=False)

    # to_hetero_feat(h, type, name)
    h_dict = to_hetero_feat(h, get_block_dsttypes(block), self.ntypes)

    return h_dict
h = layer(block, h_src, srctypes, etype, presorted=False)
# This is where the error occurs

Here, get_block_srctypes and get_block_etype take the current block object and return the 1D ntype and etype tensors required by the HGTConv layer.
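For reference, HGTConv expects x and ntype to line up row by row with the nodes of the graph it receives. A minimal sketch of building such a pair from per-type features and splitting the result back, using hypothetical helper names (flatten_typed_feats, unflatten_typed_feats) and assuming node type i gets type ID i in a fixed ntypes order. These tensors alone do not avoid the assertion above: the graph passed to HGTConv must itself be homogeneous, with its nodes ordered the same way.

```python
import torch


def flatten_typed_feats(feats, ntypes):
    """Concatenate per-type features and build a matching node-type ID tensor.

    feats:  dict mapping node type name -> tensor of shape (num_nodes_of_type, dim)
    ntypes: fixed ordering of node types; ntypes[i] is assigned type ID i
    (hypothetical helper, for illustration only)
    """
    xs, ids = [], []
    for type_id, nt in enumerate(ntypes):
        x = feats[nt]
        xs.append(x)
        # one type ID per row, so ids[k] describes row k of the concatenated x
        ids.append(torch.full((x.shape[0],), type_id, dtype=torch.long))
    return torch.cat(xs), torch.cat(ids)


def unflatten_typed_feats(x, ntype_ids, ntypes):
    """Inverse of flatten_typed_feats: split a flat tensor back into a per-type dict."""
    # boolean masking keeps the original row order within each type
    return {nt: x[ntype_ids == i] for i, nt in enumerate(ntypes)}


feats = {"gene": torch.randn(3, 4), "disease": torch.randn(2, 4)}
x, ntype = flatten_typed_feats(feats, ["gene", "disease"])
# x has shape (5, 4); ntype is tensor([0, 0, 0, 1, 1])
```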

Thanks for your time in advance!

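Hi @Azai-yx, HGTConv works on a homogeneous graph (node and edge types are passed separately as the ntype and etype tensors), so you need to convert the graph to a homogeneous one before calling HGTConv. Here is a similar issue for reference: "dgl.to_homogeneous does not support DGLBlock MFGs" (Issue #6296, dmlc/dgl on GitHub).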
