Hi All,
I got the error in the title, and my GNN is very similar to the batched GCN example here:
https://docs.dgl.ai/tutorials/basics/4_batch.html
import torch
import torch.nn as nn
import dgl.function as fn

# Built-in message and reduce functions, same as in the tutorial.
msg = fn.copy_src(src='h', out='m')
reduce = fn.mean(msg='m', out='h')

class NodeApplyModule(nn.Module):
    """Update the node feature hv with ReLU(W hv + b)."""
    def __init__(self, in_feats, out_feats, activation):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        h = self.linear(node.data['h'])
        h = self.activation(h)
        return {'h': h}

class GCN(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(GCN, self).__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        # Initialize the node features with h.
        g.ndata['h'] = feature
        # Message passing: copy each neighbor's 'h' and average into 'h'.
        g.update_all(msg, reduce)
        # Per-node transform: linear layer followed by the activation.
        g.apply_nodes(func=self.apply_mod)
        return g.ndata.pop('h')
The only difference is the initial input:
the batched GCN example uses h = g.in_degrees().view(-1, 1).float(),
while mine is h = g.ndata['attr'],
which already exists on each graph before the graphs are batched.
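To make the difference concrete, here is a minimal sketch; the two toy graphs and the random 7-dim 'attr' feature are made up for illustration:

import dgl
import torch

# Hypothetical toy graphs; each node carries a 7-dim 'attr' feature
# assigned before batching, like in my real data.
g1 = dgl.DGLGraph()
g1.add_nodes(3)
g1.add_edges([0, 1], [1, 2])
g1.ndata['attr'] = torch.randn(3, 7)

g2 = dgl.DGLGraph()
g2.add_nodes(2)
g2.add_edges([0], [1])
g2.ndata['attr'] = torch.randn(2, 7)

bg = dgl.batch([g1, g2])

# Tutorial initialization: derived from structure after batching, n_node x 1.
h_tutorial = bg.in_degrees().view(-1, 1).float()

# My initialization: set before batching; dgl.batch concatenates the
# per-graph columns, so this is n_node x 7.
h_mine = bg.ndata['attr']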
I don't know why the example works, since it updates g.ndata['h'] from n_node x 1 to n_node x 256. In my case apply_mod updates g.ndata['h'] from n_node x 7 to n_node x 64.
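Concretely, this is the call that raises the error for me (a sketch reusing bg and h_mine from above; the 7 -> 64 sizes match my model):

import torch.nn.functional as F

gcn = GCN(7, 64, F.relu)
# I expect the layer to simply replace the 'h' column with a wider one,
# giving a tensor of shape (bg.number_of_nodes(), 64).
out = gcn(bg, h_mine)
print(out.shape)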
I tried to use

from dgl.frame import Scheme
nodes.data.schemes['h'] = Scheme(shape=(64,), dtype=torch.float32)

to bypass this problem, but it still doesn't work.
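In case it helps with diagnosing, this is how I have been inspecting what DGL stores for each feature column (node_attr_schemes is an existing DGLGraph method; the commented output is only what I expect to see, not verified):

# Inspect the per-column schemes DGL keeps for node features.
print(bg.node_attr_schemes())
# e.g. {'attr': Scheme(shape=(7,), dtype=torch.float32)}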
This is quite urgent, my due date is close…
thanks!