Dear all,
I’m trying to implement the Molecular Graph Convolutions model described in this paper: https://arxiv.org/pdf/1704.01212.pdf (page 3, “Molecular Graph Convolutions”, Kearnes et al. (2016)).
In the code below, why doesn’t my nodes’ dimension change after I apply class NodeApplyModule(nn.Module)? Running this code, n_nnOri has shape 6x6 in every layer, while W3 in the second layer of my architecture is a Linear(8, 2). As a result, a size-mismatch error (6x6 vs. 8x2) is raised.
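To make the shapes concrete, here is a small standalone sketch of what I expect per layer (W3_layer1 / W3_layer2 and the random inputs are just illustrations; the dimensions follow my GCNLayer arguments below):

import torch
import torch.nn as nn

# Layer 1 edge update: node features are 3-dim, so cat(src_h, dst_h) is 6-dim per edge
W3_layer1 = nn.Linear(3 + 3, 4)   # same shape as W3 in my first GCNLayer
x1 = torch.randn(6, 6)            # 6 edges, 3 + 3 features
print(W3_layer1(x1).shape)        # torch.Size([6, 4]) -- works

# Layer 2 edge update: I expect 4-dim node features, so cat(src_h, dst_h) should be 8-dim
W3_layer2 = nn.Linear(4 + 4, 2)   # same shape as W3 in my second GCNLayer
x2 = torch.randn(6, 8)            # what I expect to reach layer 2
print(W3_layer2(x2).shape)        # torch.Size([6, 2]) -- what I want
# W3_layer2(torch.randn(6, 6))    # fails: a 6x6 input against an 8-in / 2-out weight,
#                                 # which matches the mismatch I see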
Can anybody help me by pointing out which mistake I made? Thanks a lot.
import dgl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

def create_Data():
    # Toy graph: 5 nodes, 6 directed edges, 3-dim node features, 2-dim edge features
    g = dgl.DGLGraph()
    g.add_nodes(5)
    g.add_edge(0, 1); g.add_edge(1, 2); g.add_edge(2, 3)
    g.add_edge(0, 4); g.add_edge(1, 3); g.add_edge(2, 4)
    g.ndata['h'] = torch.tensor([[1., 4., 7.], [2., 3., 5.], [3., 6., 8.], [1., 2., 1.], [2., 9., 3.]])
    g.edata['e'] = torch.tensor([[3., 5.], [4., 6.], [0., 1.], [3., 2.], [6., 7.], [4., 3.]])
    labels = torch.randint(0, 2, (g.number_of_nodes(), 1), dtype=torch.float32)
    # Randomly assign each node to exactly one of train / test / eval
    TTE_possible = [[True, False, False], [False, True, False], [False, False, True]]
    TTE_indices = np.random.choice(len(TTE_possible), g.number_of_nodes(), replace=True)
    TTE = [TTE_possible[i] for i in TTE_indices]
    train_mask = torch.tensor([TTE[i][0] for i in range(len(TTE))])
    test_mask = torch.tensor([TTE[i][1] for i in range(len(TTE))])
    eval_mask = torch.tensor([TTE[i][2] for i in range(len(TTE))])
    return g, train_mask, test_mask, eval_mask, labels
g, train_mask, test_mask, eval_mask, labels = create_Data()
labels = torch.reshape(labels, (g.number_of_nodes(),))
labels = labels.to(dtype=torch.long)
class NodeApplyModule(nn.Module):
    def __init__(self, activation=None):
        super(NodeApplyModule, self).__init__()
        self.activation = activation

    def forward(self, nodes):
        h = nodes.data['h']
        print("h NodeApply:", h)
        if self.activation:
            h = self.activation(h)
        return {'h': h}
class GCNLayer(nn.Module):
    def __init__(self, ndim_in, edims, ndim_out, activation, dropout, bias=True):
        super(GCNLayer, self).__init__()
        # W0: node self-transform; W1: message from (transformed src, dst, edge);
        # W2: edge self-transform; W3: (src, dst) pair transform; W4: new edge features
        self.W0 = nn.Linear(ndim_in, ndim_out)
        self.W1 = nn.Linear(ndim_out + ndim_out + edims, ndim_out)
        self.W2 = nn.Linear(edims, edims)
        self.W3 = nn.Linear(ndim_in + ndim_in, ndim_out)
        self.W4 = nn.Linear(edims + ndim_out, edims)
        if dropout:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = 0.0
        self.activation = activation
        self.node_update = NodeApplyModule(activation)
        self.reset_parameters()

    def reset_parameters(self):
        self.W0.reset_parameters(); self.W1.reset_parameters(); self.W2.reset_parameters()
        self.W3.reset_parameters(); self.W4.reset_parameters()

    def gcn_msg1(self, edges):
        msgOri = torch.cat([edges.src['h'], edges.dst['h'], edges.data['e']], 1)
        msg = self.W1(msgOri)
        return {'m': msg}

    def gcn_red1(self, nodes):
        red = nodes.mailbox['m'].sum(1)
        return {'h': red}

    g.register_reduce_func(gcn_red1); g.register_message_func(gcn_msg1)

    def e_update(self, edges):
        e_nn = self.W2(edges.data['e'])
        e_nn = F.relu(e_nn)
        n_nnOri = torch.cat([edges.src['h'], edges.dst['h']], 1)
        print("n_nnOri:", n_nnOri.shape)
        n_nn = self.W3(n_nnOri)
        n_nn = F.relu(n_nn)
        e = self.W4(torch.cat([e_nn, n_nn], 1))
        e = F.relu(e)
        return {"e": e}

    def forward(self, g, nfeats, efeats):
        with g.local_scope():
            g.ndata['h'] = nfeats; g.edata['e'] = efeats
            g.apply_edges(self.e_update)
            g.ndata['h'] = F.relu(self.W0(nfeats))
            g.send_and_recv(g.edges(), message_func=self.gcn_msg1, reduce_func=self.gcn_red1,
                            apply_node_func=self.node_update)
            h = g.ndata['h']; e = g.edata['e']
            return h, e
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Two layers: node feature dims 3 -> 4 -> 2, edge feature dim stays 2
        self.layers = nn.ModuleList()
        self.layers.append(GCNLayer(len(g.ndata['h'][1]), 2, 4, F.relu, 0.5))
        self.layers.append(GCNLayer(4, 2, 2, F.relu, 0.5))
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, g, nfeats, efeats):
        for layer in self.layers:
            h, e = layer(g, nfeats, efeats)
            print("=" * 75)
        return h
model = Net(); print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fcn = torch.nn.CrossEntropyLoss()
for epoch in range(1):
    model.train()
    logits = model(g, g.ndata['h'], g.edata['e'])
    logp = F.log_softmax(logits, 1)
    loss = loss_fcn(logits[train_mask], labels[train_mask])
    optimizer.zero_grad(); loss.backward(); optimizer.step()