Help: Error "Expect number of features to match number of nodes (len(u)). Got 96 and 1 instead"

import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F


class RGAT(nn.Module):
    def __init__(self, in_dim,
                 out_dim,
                 h_dim,
                 etypes,
                 num_heads,
                 num_hidden_layers=1,
                 dropout=0,
                 ):
        super(RGAT, self).__init__()
        self.rel_names = etypes
        self.layers = nn.ModuleList()
        # Separate conv used only by the embed=True branch below.
        self.conv = dglnn.HeteroGraphConv({
            rel: dglnn.GATv2Conv(in_dim, h_dim, num_heads=num_heads, bias=True, allow_zero_in_degree=False)
            for rel in etypes
        })
        # Input-to-hidden layer.
        self.layers.append(RGATLayer(
            in_dim, h_dim, num_heads, self.rel_names, activation=F.relu, dropout=dropout, last_layer_flag=False))
        # Hidden layers; input width is h_dim * num_heads because the previous
        # layer flattens its per-head outputs.
        for i in range(num_hidden_layers):
            self.layers.append(RGATLayer(
                h_dim * num_heads, h_dim, num_heads, self.rel_names, activation=F.relu, dropout=dropout, last_layer_flag=False
            ))
        # Output layer averages over heads instead of flattening.
        self.layers.append(RGATLayer(
            h_dim * num_heads, out_dim, num_heads, self.rel_names, activation=None, last_layer_flag=True))

    def forward(self, hg, h_dict=None, embed=False, eweight=None, **kwargs):
        if embed:
            # Return node embeddings for the explainer instead of logits.
            h_dict = self.conv(hg, h_dict)
            output = {}
            for n_type, h in h_dict.items():
                output[n_type] = h.flatten(1)
            return output
        if hasattr(hg, 'ntypes'):
            # Full-graph training.
            for layer in self.layers:
                h_dict = layer(hg, h_dict)
        else:
            # Mini-batch training on a list of blocks.
            for layer, block in zip(self.layers, hg):
                h_dict = layer(block, h_dict)
        return h_dict


class RGATLayer(nn.Module):

    def __init__(self,
                 in_feat,
                 out_feat,
                 num_heads,
                 rel_names,
                 activation=None,
                 dropout=0.0,
                 last_layer_flag=False,
                 bias=True):
        super(RGATLayer, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.num_heads = num_heads
        self.activation = activation
        self.dropout = nn.Dropout(dropout)
        self.last_layer_flag = last_layer_flag
        # One GATv2Conv per relation, combined by HeteroGraphConv.
        self.conv = dglnn.HeteroGraphConv({
            rel: dglnn.GATv2Conv(in_feat, out_feat, num_heads=num_heads, bias=bias, allow_zero_in_degree=False)
            for rel in rel_names
        })

    def forward(self, g, h_dict):
        # GATv2Conv outputs are (N, num_heads, out_feat) per node type.
        h_dict = self.conv(g, h_dict)
        output = {}
        for n_type, h in h_dict.items():
            if self.last_layer_flag:
                # Average over heads in the last layer.
                h = h.mean(1)
            else:
                # Concatenate heads in hidden layers.
                h = h.flatten(1)
            output[n_type] = h.squeeze()
        return output
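
For context, this is roughly how I construct and call the model; the graph, features, and dimensions below are placeholders, not my actual data:

# Hypothetical usage with a toy heterograph (placeholder data, just to show
# the two forward paths).
import dgl
import torch

g = dgl.heterograph({
    ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 0])),
})
feats = {'user': torch.randn(2, 5)}

model = RGAT(in_dim=5, out_dim=2, h_dim=8, etypes=g.etypes, num_heads=3)
logits = model(g, feats)                  # classification path
embeddings = model(g, feats, embed=True)  # path used by the explainer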

I am trying to run Hetero PGExplainer on the RGAT model above. Model training runs fine, but as soon as the explainer's training starts (which calls forward with embed=True), the h_dict returned from that branch triggers the error in the title. Running the same setup with GraphConv layers instead of the GAT layers does not give any issues.
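
For what it's worth, here is a minimal sketch (a toy graph I made up, not my pipeline) of the shape difference I suspect matters: GATv2Conv keeps a separate head dimension, while GraphConv returns plain 2-D features.

# GATv2Conv output is 3-D (per-head); GraphConv output is 2-D.
import dgl
import torch
from dgl.nn.pytorch import GATv2Conv, GraphConv

g = dgl.graph(([0, 1, 2], [1, 2, 0]))  # 3-node cycle, so no zero in-degrees
x = torch.randn(3, 5)

gat = GATv2Conv(5, 4, num_heads=3)
print(gat(g, x).shape)  # torch.Size([3, 3, 4]) -- (N, num_heads, out_feats)

gcn = GraphConv(5, 4)
print(gcn(g, x).shape)  # torch.Size([3, 4])    -- (N, out_feats)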

Could you provide a minimal reproducible example as well as the complete stack trace?

While I solved the initial issue by flattening the returned embeddings, the problem still persists at a later stage, and it does not occur when I use the GraphConv layer. Here is reproducible code:

import dgl
import dgl.nn.pytorch as dglnn
import torch as th
import numpy as np


class RGAT(th.nn.Module):
    def __init__(
        self, in_feats, embed_dim, out_feats, canonical_etypes, num_heads=3
    ):
        super(RGAT, self).__init__()
        self.conv = dglnn.HeteroGraphConv(
            {
                c_etype: dglnn.GATv2Conv(in_feats, embed_dim, num_heads=num_heads)
                for c_etype in canonical_etypes
            }
        )
        # Input width is embed_dim * num_heads because the first layer's
        # per-head outputs are flattened before reaching this one.
        self.conv2 = dglnn.HeteroGraphConv(
            {
                c_etype: dglnn.GATv2Conv(embed_dim * num_heads, out_feats, num_heads=num_heads)
                for c_etype in canonical_etypes
            }
        )

    def forward(self, g, h, embed=False, edge_weight=None):
        h = self.conv(g, h)
        out_put = {}
        for n_type, h in h.items():
            # Flatten heads: (N, num_heads, embed_dim) -> (N, num_heads * embed_dim).
            out_put[n_type] = h.flatten(1)
        if embed:
            # This flattened dict is what the explainer receives.
            return out_put
        h = self.conv2(g, out_put)
        out_put = {}
        for n_type, h in h.items():
            # Average over heads for the final logits.
            out_put[n_type] = h.mean(1).squeeze()
        return out_put


# Build a small synthetic heterograph.
input_dim = 5
hidden_dim = 5
num_classes = 2
data_dict = {
    ('user', 'follows', 'user'): (th.tensor([0, 1]), th.tensor([1, 2])),
    ('user', 'follows', 'topic'): (th.tensor([1, 1]), th.tensor([0, 1])),
    ('user', 'plays', 'game'): (th.tensor([0, 1, 2, 3]), th.tensor([0, 1, 1, 0])),
}
g = dgl.heterograph(data_dict)
transform = dgl.transforms.AddReverse()
g = transform(g)
embeds = th.nn.ParameterDict()
for ntype in g.ntypes:
    # Learnable node embeddings, since the toy graph has no input features.
    embed = th.nn.Parameter(th.Tensor(g.num_nodes(ntype), input_dim))
    th.nn.init.xavier_uniform_(embed, gain=th.nn.init.calculate_gain("relu"))
    embeds[ntype] = embed

model = RGAT(input_dim, hidden_dim, num_classes, g.canonical_etypes)
optimizer = th.optim.Adam(model.parameters())
for epoch in range(10):
    logits = model(g, embeds)['user']
    loss = th.nn.functional.cross_entropy(logits, th.tensor([0,1,1,0]))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(epoch)
explainer = dgl.nn.HeteroPGExplainer(
    model, hidden_dim, num_hops=1, explain_graph=False
)
# Temperature annealing schedule for the explainer's concrete distribution.
init_tmp, final_tmp = 5.0, 1.0
optimizer_exp = th.optim.Adam(explainer.parameters(), lr=0.01)
# Detach the learned embeddings to use them as plain input features.
pg_feat = {item: embeds[item].data for item in embeds}
for epoch in range(20):
    tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / 20))
    loss = explainer.train_step_node(
        { ntype: g.nodes(ntype) for ntype in g.ntypes },
        g, pg_feat, tmp
    )
    optimizer_exp.zero_grad()
    loss.backward()
    optimizer_exp.step()

And the stack trace is:
Traceback (most recent call last):
  File "/home/iamroot/node_classification/try2.py", line 79, in <module>
    loss = explainer.train_step_node(
  File "/home/iamroot/edge/lib/python3.10/site-packages/dgl/nn/pytorch/explain/pgexplainer.py", line 695, in train_step_node
    prob, _, batched_graph, inverse_indices = self.explain_node(
  File "/home/iamroot/edge/lib/python3.10/site-packages/dgl/nn/pytorch/explain/pgexplainer.py", line 1052, in explain_node
    batched_embed = self.elayers(batched_embed)
  File "/home/iamroot/edge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/iamroot/edge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/iamroot/edge/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/iamroot/edge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/iamroot/edge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/iamroot/edge/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (24x45 and 15x64)
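
If the arithmetic helps: with hidden_dim = 5 and num_heads = 3, the embed=True branch returns features of width 5 * 3 = 15, and the explainer appears to concatenate three such embeddings per edge, giving the 45 columns of mat1 (3 * 15). The explainer's first Linear layer, however, seems to have been sized from the hidden_dim = 5 that I passed to HeteroPGExplainer (3 * 5 = 15 input features, the 15x64 mat2). Assuming the second constructor argument should match the width of the embeddings the model actually returns, a possible fix would be:

# Sketch of a possible fix (my assumption, untested): size the explainer to
# the flattened embedding width, embed_dim * num_heads = 5 * 3 = 15, instead
# of the per-head hidden_dim.
num_heads = 3
explainer = dgl.nn.HeteroPGExplainer(
    model, hidden_dim * num_heads, num_hops=1, explain_graph=False
)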