Mini-batch training for model with aggregation of node features and edge features

Hi, I want to correct the codes of my current model / codes in the training loops for mini-batch training. The model takes both the node features and edge features for aggregation. Currently, the model fails because:

“DGLError: Expect number of features to match number of edges. Got 5814 and 2762 instead.”

Could you please suggest how to solve this issue? Thanks in advance!

Below is my current model (refer to How to use mutiple node features and edge features in GCN?):

class GNNLayer(nn.Module):
    def __init__(self, ndim, edim, out_dim, activation=F.relu, aggregator_type = 'sum'): 
        self.aggregator_types_map = {'sum': fn.sum, 'mean': fn.mean, 'max': fn.max}
        aggregator_type = aggregator_type.lower()
        assert aggregator_type in self.aggregator_types_map.keys(), f'InputError: Expect aggregator_type to be one of the following: \n{self.aggregator_types_map.keys()}' 
        self.aggregator_type = aggregator_type
        super(GNNLayer, self).__init__()
        self.W_msg = nn.Linear(ndim + edim, out_dim)       
        self.W_apply = nn.Linear(ndim + out_dim, out_dim) #ndim = original node features. #out_dim = aggregated features
        self.activation = activation

    def message_func(self, edges):        
        return {'m': self.activation(self.W_msg([edges.src['h'],['h']], dim=1)))}

    def forward(self, g_dgl, nfeats, efeats):
        # check type(g_dgl) : whether it is a DGL block or a DGL graph
        g_dgl_is_graph = isinstance(g_dgl, dgl.DGLGraph)
        g_dgl_is_block = g_dgl.is_block
        if g_dgl_is_block:    # case 1: input is a DGL block
            block = g_dgl
            h = nfeats
            # with g_dgl.local_scope():
            with block.local_scope():
                # g_dgl.ndata['h'] = nfeats 
                h_src = h
                h_dst = h[:block.number_of_dst_nodes()]
                block.srcdata['h'] = h_src
                block.dstdata['h'] = h_dst 

                # g_dgl.edata['h'] = efeats  
                block.edata['h'] = efeats 

                agg_func = self.aggregator_types_map[self.aggregator_type]
                # g_dgl.update_all(self.message_func, agg_func(msg='m', out='h_neigh'))
                block.update_all(self.message_func, agg_func(msg='m', out='h_neigh'))

                # g_dgl.ndata['h'] = self.activation(self.W_apply([g_dgl.ndata['h'], g_dgl.ndata['h_neigh']], dim=1)))

                # return g_dgl.ndata['h']
                return self.activation(self.W_apply(
                    [block.dstdata['h'], block.dstdata['h_neigh']], dim=1
        elif g_dgl_is_graph:    # input is a DGL graph
            with g_dgl.local_scope():
                g = g_dgl
                g.ndata['h'] = nfeats 
                g.edata['h'] = efeats  

                agg_func = self.aggregator_types_map[self.aggregator_type]
                g.update_all(self.message_func, agg_func(msg='m', out='h_neigh')) 
                g.ndata['h'] = self.activation(self.W_apply([g.ndata['h'], g.ndata['h_neigh']], dim=1)))

                return g.ndata['h']
    def reset_parameters(self):
        nn.init.constant_(self.W_msg.bias, 0.0)
        nn.init.constant_(self.W_apply.bias, 0.0)

class GNN_FC(nn.Module):
    def __init__(self, conv_ndims, edim, fc_ndim_out, emlp_nlayer, activation=F.relu, dropout=0.5, xavier_init=True, aggregator_type='sum'):
        super(GNN_FC, self).__init__()
        self.edge_prelayers = nn.ModuleList()
        for i in range(emlp_nlayer):
                nn.Linear(edim, edim)
        self.layers = nn.ModuleList()
        for _ndim_in, _ndim_out in conv_ndims:
            self.layers.append(GNNLayer(_ndim_in, edim, _ndim_out, activation, aggregator_type=aggregator_type))
            # self.layers.append(GNNLayer_v2(_ndim_in, edim, _ndim_out, activation, aggregator_type=aggregator_type))
        self.fc = nn.Linear(_ndim_out, fc_ndim_out)#, n_classes   # added
        self.dropout = nn.Dropout(p=dropout)
        self.xavier_init = xavier_init

    def forward(self, g, node_feat, edge_feat): 
        g_is_graph = isinstance(g, dgl.DGLGraph)
        g_is_list_of_blocks = (isinstance(g, list) and g[0].is_block) if not g_is_graph else False
        if not g_is_graph:
            assert isinstance(g, list) and len(g) == len(self.layers), "InputError for g: Expect either DGL graph or a list of DGL blocks. \nIf g is a list of DGL blocks, the length of the list should be the same as the number of conv layers in the model."
        h = node_feat
        e = edge_feat
        for i, layer in enumerate(self.edge_prelayers):
            e = layer(e)
        for i, layer in enumerate(self.layers):
            if i != 0:
                h = self.dropout(h)

            if g_is_graph:
                h = layer(g, h, e)
            elif g_is_list_of_blocks:
                h = layer(g[i], h, e)

        h = self.fc(h)
        return h
    def reset_parameters(self):
        print('Weights : \t Xavier Init')
        print('Bias: \t 0')
        for layer in self.layers:
        for layer in self.edge_prelayers:

This is how I train my model (part of the codes of the training loop) (refer to 6.5 Implementing Custom GNN Module for Mini-batch Training and 6.1 Training GNN for Node Classification with Neighborhood Sampling):

train_graph_sampler = dgl.dataloading.MultiLayerFullNeighborSampler(num_modellayers)
train_dataloader = DataLoader(
    indices = torch.nonzero(train_mask).squeeze(), 
    graph_sampler = train_graph_sampler,
    num_workers = 4

for epoch in tqdm(range(epochs)):       
    train_loss = 0

    for _input_nodes, _output_nodes, _blocks in tqdm(train_dataloader):

        # Move tensors to device
        _blocks = [ for _block in _blocks]

        # Forward pass
        _srcnode_feat = _blocks[0].srcdata['feat']
        _edge_feat = _blocks[0].edata['feat']
        _labels = _blocks[-1].dstdata['label'].cpu()

        if edge_feat is None:
            _out = model(_blocks, _srcnode_feat)
            _out = model(_blocks, _srcnode_feat, _edge_feat)

        if learning_type=='SSL':    # Semi-supervised learning
            _tensor_mask = torch.where(torch.isin(train_labelled_nodes,[0]
            _loss = _criterion(
                _labels[_tensor_mask].to(torch.int64), weight=torch.tensor(criterion_weight)
        elif learning_type == 'SL':    # Supervised learning   
            _loss = _criterion(_out,, weight=torch.tensor(criterion_weight))    # criterion(_out[train_mask], label[train_mask].to(torch.int64), weight=torch.tensor(criterion_weight))

        train_loss += _loss.item()

Currently the model is trained under supervised learning (SL) mode (i.e. learning_type == ‘SL’).

Hi @chpoonag, looks like you are using the edge features of the first block for the forward functions of all the blocks. This may be a problem because different blocks may have different number of edges.