Node and edge features with GCN

Dear all,
I’ve been trying to understand how to use a GCN for node classification that takes both node and edge features into account. Here is a glance at the features:
Node feature:
g_dgl.ndata['n_feat']: tensor([[[1, 1, 1],[1, 1, 1],[0, 1, 1]],…,[[1, 1, 0],[1, 1, 1],[1, 1, 1]]]) -> torch.Size([116054, 3, 3])
Edge feature1:
g_dgl.edata['e_feat1']: tensor([[0, 0, 1],…,[0, 0, 1]]) -> torch.Size([346320, 3])
Edge feature2:
g_dgl.edata['e_feat2']: tensor([[0, 1, 1],…,[1, 0, 1]]) -> torch.Size([346320, 3])
Edge feature3:
g_dgl.edata['e_feat3']: tensor([[0, 1, 1],…,[1, 0, 1]]) -> torch.Size([346320, 3])

So far, I have got several pointers and came up with the following implementation of a GCN:

class GCN(nn.Module):
  def __init__(self, dim_in, dim_hid, dim_out, activation, dropout):
    super(GCN, self).__init__()
    self.layers = nn.ModuleList()
    # two GraphConv layers: input -> hidden -> output
    self.layers.append(GraphConv(dim_in, dim_hid, activation=activation))
    self.layers.append(GraphConv(dim_hid, dim_out))
    self.dropout = nn.Dropout(p=dropout)
  def forward(self, g_dgl, feats):
    h = feats
    for i, layer in enumerate(self.layers):
        if i != 0:
            h = self.dropout(h)  # apply dropout between layers only
        h = layer(g_dgl, h)
    return h

The model is created with model = GCN(dim_in, dim_hid, dim_out, F.relu, dropout), trained after calling model.train(), and the forward pass is run with model(g_dgl, g_dgl.ndata["n_feat"]).
As you can see, the class doesn’t take the edge features into account yet. This class was suggested to me:

class GNNLayer(nn.Module):
  def __init__(self, node_dims, edge_dims, output_dims, activation):
    super(GNNLayer, self).__init__()
    self.W_msg = nn.Linear(node_dims + edge_dims, output_dims)
    self.W_apply = nn.Linear(output_dims * 2, output_dims)
    self.activation = activation
  def message_func(edges):
    return {'m': F.relu(self.W_msg(th.cat([edges.src['h'], edges.data['h']], 0)))}
  def forward(self, g, node_features, edge_features):
    with g.local_scope():
      g.ndata['n_feat'] = node_features
      g.edata['e_feat'] = edge_features
      g.update_all(self.message_func, fn.sum('m', 'h_neigh'))
      g.ndata['h'] = F.relu(self.W_apply(th.cat([g.ndata['h'], g.ndata['h_neigh']], 0)))
      return g.ndata['h']

With that class, I replace self.layers.append(GraphConv(dim_in, dim_hid, activation=activation)) with self.layers.append(GNNLayer(3, 3, 1, activation=activation)).

  1. However, I’m not sure whether just replacing that line already takes all 3 of the edge features I mentioned into account.
  2. In addition, with self.layers.append(GraphConv(dim_in, dim_out, activation=activation)), it is clear that the dim_out of the previous layer becomes the dim_in of the next layer. But when the GNNLayer class takes node and edge features, how is the same principle applied (dim_out becomes dim_in)?

Thanks all for any suggestions and directions.

  1. It seems that you have 3 channels for your features and I assume you combine the three edge features with torch.stack([efeat1, efeat2, efeat3], dim=1)? (See the sketch after this list.)
  2. I think you can just replace GraphConv(dim_in, dim_hid, activation=activation) with GNNLayer(node_dim_in, edge_dim_in, dim_hid, activation).
  3. node_dim_in for layer l+1 should be dim_hid for layer l while edge_dim_in will be kept unchanged across all layers since we are not updating edge representations.
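
A minimal sketch of what I understand points 1 and 3 to mean (my own illustration; the tensor values are placeholders, not the real data):

import torch

# three per-edge feature tensors of shape [num_edges, 3], stacked into one
# [num_edges, 3, 3] tensor with one "channel" per edge feature
e_feat1 = torch.zeros(346320, 3)
e_feat2 = torch.zeros(346320, 3)
e_feat3 = torch.zeros(346320, 3)
e_feat = torch.stack([e_feat1, e_feat2, e_feat3], dim=1)  # -> torch.Size([346320, 3, 3])

# layers then chain on the node dimension only, e.g. (hypothetical sizes):
# GNNLayer(3, 3, 16) -> GNNLayer(16, 3, 8) -> GNNLayer(8, 3, num_classes)
# where the edge dimension (the middle argument) stays 3 throughout.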

Thanks @mufeili,

I tried the approach with the simple implementation below:

class GNNLayer2(nn.Module):
  def __init__(self, ndim_in, edims, ndim_out, activation):
    super(GNNLayer2, self).__init__()
    self.W_msg = nn.Linear(ndim_in + edims, ndim_out)
    self.W_apply = nn.Linear(ndim_out * 2, ndim_out)
    self.activation = activation
  def message_func(edges):
    return {'m': F.relu(self.W_msg(th.cat([edges.src['h'], edges.data['h']], 0)))}    
  def forward(self, g_dgl, node_features, edge_features):
    with g_dgl.local_scope():
      g_dgl.ndata['n_feat'] = node_features
      g_dgl.edata['e_feat'] = edge_features
      g_dgl.update_all(self.message_func, fn.sum('m', 'h_neigh'))
      g_dgl.ndata['h'] = F.relu(self.W_apply(th.cat([g_dgl.ndata['h'], g_dgl.ndata['h_neigh']], 0)))
      return g_dgl.ndata['h']

class GCN2(nn.Module):
  def __init__(self, ndim_in, ndim_out, edim, activation, dropout):
    super(GCN2, self).__init__()
    self.layers = nn.ModuleList()
    self.layers.append(GNNLayer2(ndim_in, edim, 50, activation))
    self.layers.append(GNNLayer2(50, edim, 25, activation))
    self.layers.append(GNNLayer2(25, edim, ndim_out, activation))
    self.dropout = nn.Dropout(p=dropout)
  def forward(self, g_dgl, feats):
    h = feats
    for i, layer in enumerate(self.layers):
        if i != 0:
            h = self.dropout(h)
        h = layer(g_dgl, h)
    return h

 e_feat1 = torch.tensor([[0, 0, 1],[0, 0, 1],[0, 0, 1]])
 e_feat2 = torch.tensor([[0, 1, 1],[1, 0, 1],[0, 0, 1]]) 
 e_feat3 = torch.tensor([[1, 0, 1],[1, 0, 1],[0, 0, 1]]) 
 e_feat = torch.stack([e_feat1, e_feat2, e_feat3], dim=1)
 g = DGLGraph([(0, 2), (0, 3), (1, 2)])
 g.edata['e_feat'] = e_feat
 g.ndata['n_feat'] = torch.tensor([[[1, 1, 1],[1, 1, 1],[0, 1, 1]],
                                   [[1, 1, 1],[1, 1, 1],[0, 1, 1]],
                                   [[1, 1, 1],[1, 1, 1],[0, 1, 1]],
                                   [[1, 1, 0],[1, 1, 1],[1, 1, 1]]])
 model = GCN2(3, 1, 3, F.relu, dropout)
 loss_fcn = torch.nn.CrossEntropyLoss()
 optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
 for epoch in range(2):
   model.train()
   logits = model(g.ndata['n_feat'], g.edata['e_feat'])
   loss = loss_fcn(logits[train_mask], labels[train_mask])
   optimizer.zero_grad()
   loss.backward()
   optimizer.step()
   acc = evaluate(model, features, labels, val_mask)

  1. It produces an error: forward() missing 1 required positional argument: 'edge_features'. I thought I passed the edge_features via model(g.ndata['n_feat'], g.edata['e_feat'])?
  2. I am trying to understand the logic behind self.W_msg = nn.Linear(ndim_in + edims, ndim_out) and self.W_apply = nn.Linear(ndim_out * 2, ndim_out), but I couldn’t find a similar example in the library. What do these 2 lines actually do? Why is it ndim_out * 2 and not some other number?

Thanks for the help. I appreciate it so much.

  1. You need to pass the graph, node features and edge features. Currently you are not passing the graph.
  2. self.W_msg and self.W_apply separately define two linear layers. self.W_msg projects the concatenated neighbor and edge features before passing the message. self.W_apply projects the concatenation of the old node representations and the aggregated new representations. This is why we have ndim_out * 2 here.
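
To make the shapes concrete, here is a rough standalone sketch of those two projections (my own illustration with made-up sizes, assuming for simplicity that the node's own representation already has size ndim_out, as the explanation above implies):

import torch
import torch.nn as nn
import torch.nn.functional as F

ndim_in, edims, ndim_out = 3, 3, 4
W_msg = nn.Linear(ndim_in + edims, ndim_out)    # mixes a neighbor's features with the edge features
W_apply = nn.Linear(ndim_out * 2, ndim_out)     # mixes a node's old state with its aggregated messages

src_h = torch.randn(5, ndim_in)    # features of the source node of each of 5 incoming edges
edge_h = torch.randn(5, edims)     # features of those 5 edges
msgs = F.relu(W_msg(torch.cat([src_h, edge_h], dim=1)))  # one message per edge, size ndim_out

h_old = torch.randn(ndim_out)      # the node's current representation (assumed size ndim_out)
h_neigh = msgs.sum(dim=0)          # sum-aggregate the incoming messages
h_new = F.relu(W_apply(torch.cat([h_old, h_neigh], dim=0)))  # hence the ndim_out * 2 input size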

Thanks so much @mufeili,

But now I have got another error saying that message_func() takes 1 positional argument but 2 were given. I modified the line that causes the error to g_dgl.update_all(self.message_func(g_dgl.edges), fn.sum('m', 'h_neigh')), but the error remains the same. So why does it say I supply message_func with 2 arguments?

Thanks a lot

  1. By passing the graph, node features and edge features, I was referring to model(g, g.ndata['n_feat'], g.edata['e_feat']).
  2. You don’t need to explicitly pass edges in g_dgl.update_all, which is automatically handled.

Thanks @mufeili,

If g_dgl.update_all handles the edges automatically, and I call message_func() inside it to concatenate neighbour and edge features (through 'm': F.relu(self.W_msg(torch.cat([edges.src['h'], edges.data['h']], 0)))), how does it know edges.src['h'] and edges.data['h']? It feels like edges.src['h'] and edges.data['h'] are unknown inside message_func().

If I’m correct, I can set edges.data['h'] = edges.data['e_feat'], right? But what about edges.src['h']? Is this the node features?

Thanks so much

  1. This is realized via EdgeBatch generated internally.
  2. You cannot change the edge features in the message function.
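
To illustrate point 1, here is a small self-contained sketch (my own example, assuming a DGL version where dgl.graph is available) of how update_all hands an EdgeBatch to a user-defined message function:

import dgl
import dgl.function as fn
import torch

g = dgl.graph(([0, 0, 1], [2, 3, 2]))               # 3 edges: 0->2, 0->3, 1->2
g.ndata['h'] = torch.randn(g.number_of_nodes(), 3)  # node features
g.edata['h'] = torch.randn(g.number_of_edges(), 3)  # edge features

def message_func(edges):
    # `edges` is an EdgeBatch built internally by DGL: edges.src['h'] holds the
    # source-node feature of every edge ([num_edges, 3]) and edges.data['h']
    # holds the feature stored on each edge itself ([num_edges, 3]).
    return {'m': torch.cat([edges.src['h'], edges.data['h']], dim=1)}

g.update_all(message_func, fn.sum('m', 'h_neigh'))   # sum messages per destination node
print(g.ndata['h_neigh'].shape)                      # torch.Size([4, 6])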

Thanks @mufeili,

Right, if I cannot change the edge features in the message function, then edges.data['h'] = edges.data['e_feat'] is not correct. Got it. But I still get the error KeyError: 'h' when executing return {'m': F.relu(self.W_msg(torch.cat([edges.src['h'], edges.data['h']], 0)))}. How can that happen if the edge batches are generated internally?

Thanks so much

That’s expected because you did not set the node/edge features under the key name 'h'. Change

g_dgl.ndata['n_feat'] = node_features
g_dgl.edata['e_feat'] = edge_features

to

g_dgl.ndata['h'] = node_features
g_dgl.edata['h'] = edge_features

Thanks @mufeili,

I have spent some hours understanding how the code works and arrived at this final, very simple implementation.

class GNNLayer2(nn.Module):
  def __init__(self, ndim_in, edims, ndim_out, activation):
    super(GNNLayer2, self).__init__()
    self.W_msg = nn.Linear(ndim_in + edims, ndim_out)
    self.W_apply = nn.Linear(ndim_out * 2, ndim_out)
    self.activation = activation

  def message_func(self, edges):
    return {'m': F.relu(self.W_msg(torch.cat([edges.src['h'], edges.data['h']], 0)))}  
    
  def forward(self, g_dgl):
    with g_dgl.local_scope():
      g = g_dgl
      g.ndata['h'] = g.ndata['n_feat']
      g.edata['h'] = g.edata['e_feat']
      g.update_all(self.message_func, fn.sum('m', 'h_neigh'))
      g.ndata['h'] = F.relu(self.W_apply(torch.cat([g.ndata['h'], g.ndata['h_neigh']], 0)))
      return g.ndata['h']

class GCN2(nn.Module):
  def __init__(self, ndim_in, ndim_out, edim, activation, dropout):
    super(GCN2, self).__init__()
    self.layers = nn.ModuleList()
    self.layers.append(GNNLayer2(ndim_in, edim, 50, activation))
    self.layers.append(GNNLayer2(50, edim, 25, activation))
    self.layers.append(GNNLayer2(25, edim, ndim_out, activation))
    self.dropout = nn.Dropout(p=dropout)

  def forward(self, g_dgl):
    g = g_dgl
    h = g.ndata['n_feat']
    for i, layer in enumerate(self.layers):
        if i != 0:
            h = self.dropout(h)
        h = layer(g)
    return h

model = GCN2(3, 1, 3, F.relu, 0.5)

Which produce a model as followed:

Model: GCN2(
  (layers): ModuleList(
    (0): GNNLayer2(
      (W_msg): Linear(in_features=6, out_features=50, bias=True)
      (W_apply): Linear(in_features=100, out_features=50, bias=True)
    )
    (1): GNNLayer2(
      (W_msg): Linear(in_features=53, out_features=25, bias=True)
      (W_apply): Linear(in_features=50, out_features=25, bias=True)
    )
    (2): GNNLayer2(
      (W_msg): Linear(in_features=28, out_features=1, bias=True)
      (W_apply): Linear(in_features=2, out_features=1, bias=True)
    )
  )
  (dropout): Dropout(p=0.5, inplace=False)
)

  1. I got a size mismatch though: size mismatch, m1: [18 x 3], m2: [6 x 50] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:41. As you can see, the first layer has 6 in_features and 50 out_features, so why does m1 have [18 x 3]? I think it has something to do with self.W_msg = nn.Linear(ndim_in + edims, ndim_out), doesn't it? How should I deal with it?

  2. Correct me if I’m wrong: the activation function has already been applied in forward in GNNLayer2, right? How would it be different if I also applied the activation function in forward in GCN2, such as

    if self.activation: g.ndata['h'] = self.activation(g.ndata['h'])

Thanks for your help

What’s the shape of g.ndata['h'] and g.edata['h'] now?

g.ndata['h'] is torch.Size([4, 3, 3]) and g.edata['h'] is torch.Size([3, 3, 3]).

Can you change

return {'m': F.relu(self.W_msg(torch.cat([edges.src['h'], edges.data['h']], 0)))}  

to

return {'m': F.relu(self.W_msg(torch.cat([edges.src['h'], edges.data['h']], 1)))}  

and

g.ndata['h'] = F.relu(self.W_apply(torch.cat([g.ndata['h'], g.ndata['h_neigh']], 0)))

to

g.ndata['h'] = F.relu(self.W_apply(torch.cat([g.ndata['h'], g.ndata['h_neigh']], 2)))

I actually tried that before, but it still throws the same error.

Try this one

import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F

class GNNLayer2(nn.Module):
    def __init__(self, ndim_in, edims, ndim_out, activation):
        super(GNNLayer2, self).__init__()
        self.W_msg = nn.Linear(ndim_in + edims, ndim_out)
        self.W_apply = nn.Linear(ndim_in + ndim_out, ndim_out)
        self.activation = activation

    def message_func(self, edges):
        return {'m': F.relu(self.W_msg(torch.cat([edges.src['h'], edges.data['h']], 2)))}

    def forward(self, g_dgl, nfeats, efeats):
        with g_dgl.local_scope():
            g = g_dgl
            g.ndata['h'] = nfeats
            g.edata['h'] = efeats
            g.update_all(self.message_func, fn.sum('m', 'h_neigh'))
            g.ndata['h'] = F.relu(self.W_apply(torch.cat([g.ndata['h'], g.ndata['h_neigh']], 2)))
            return g.ndata['h']


class GCN2(nn.Module):
    def __init__(self, ndim_in, ndim_out, edim, activation, dropout):
        super(GCN2, self).__init__()
        self.layers = nn.ModuleList()
        self.layers.append(GNNLayer2(ndim_in, edim, 50, activation))
        self.layers.append(GNNLayer2(50, edim, 25, activation))
        self.layers.append(GNNLayer2(25, edim, ndim_out, activation))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, g, nfeats, efeats):
        for i, layer in enumerate(self.layers):
            if i != 0:
                nfeats = self.dropout(nfeats)
            nfeats = layer(g, nfeats, efeats)
        return nfeats.sum(1)

if __name__ == '__main__':
    model = GCN2(3, 1, 3, F.relu, 0.5)
    g = dgl.DGLGraph([[0, 2], [2, 3]])
    nfeats = torch.randn((g.number_of_nodes(), 3, 3))
    efeats = torch.randn((g.number_of_edges(), 3, 3))
    model(g, nfeats, efeats)
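
For anyone following along, a quick standalone shape check of why the concatenation has to happen on the last (feature) dimension (my own sketch using the toy sizes from this thread):

import torch
import torch.nn as nn
import torch.nn.functional as F

src_h = torch.randn(2, 3, 3)    # edges.src['h']: [num_edges, channels, ndim_in]
edge_h = torch.randn(2, 3, 3)   # edges.data['h']: [num_edges, channels, edims]

# nn.Linear acts on the last dimension, so the node and edge features must be
# concatenated along dim=2 to produce the expected 6 input features per channel.
m = F.relu(nn.Linear(3 + 3, 50)(torch.cat([src_h, edge_h], dim=2)))
print(m.shape)  # torch.Size([2, 3, 50])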

Thanks @mufeili, and I’m sorry for the late response in my own thread. It does work; thank you so much for your help.

With this layer, we expect logits = model(g, n_feat, e_feat) to have size torch.Size([4, 1]), where 4 is the number of nodes. I also want to do binary classification with torch.nn.BCEWithLogitsLoss(), with my target labels tensor([[0.],[0.],[0.],[1.]]) and train_mask tensor([False, False, True, False]), similar to the GCN binary classification example on the Cora dataset. The code I use is similar to that example:

logits = model(g, n_feat, e_feat)
loss_fcn = torch.nn.BCEWithLogitsLoss()
loss = loss_fcn(logits[train_mask], labels[train_mask])

So, maybe a stupid question: where am I supposed to declare the total number of labels I want to predict (in this case 2, being [0.] or [1.]), since they are not explicitly declared in the GCN2 class? And what if I do multiclass classification?

Thank you

In the case of binary classification, declare ndim_out to be 1 and use torch.nn.BCEWithLogitsLoss. In the case of more than two classes, declare ndim_out to be the number of classes and use torch.nn.CrossEntropyLoss().
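
A minimal sketch of the two setups (my own example with made-up logits; only the output dimension and the loss change):

import torch
import torch.nn as nn

# binary: ndim_out = 1, one logit per node, float targets of shape [N, 1]
logits_bin = torch.randn(4, 1)
labels_bin = torch.tensor([[0.], [0.], [0.], [1.]])
loss_bin = nn.BCEWithLogitsLoss()(logits_bin, labels_bin)

# multiclass: ndim_out = num_classes, integer class indices as targets
num_classes = 5
logits_multi = torch.randn(4, num_classes)
labels_multi = torch.tensor([0, 3, 1, 4])
loss_multi = nn.CrossEntropyLoss()(logits_multi, labels_multi)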

Hey!
I was just going through the code; shouldn’t it be:
g.ndata['h'] = self.activation(self.W_apply(torch.cat([g.ndata['h'], g.ndata['h_neigh']], 2))) … ?

Yes, that is correct.
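
For completeness, here is a sketch of the layer with that one line changed to use the stored activation (my own assembly of the code above, not a version posted in the thread):

import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F

class GNNLayer2(nn.Module):
    def __init__(self, ndim_in, edims, ndim_out, activation):
        super(GNNLayer2, self).__init__()
        self.W_msg = nn.Linear(ndim_in + edims, ndim_out)
        self.W_apply = nn.Linear(ndim_in + ndim_out, ndim_out)
        self.activation = activation

    def message_func(self, edges):
        return {'m': F.relu(self.W_msg(torch.cat([edges.src['h'], edges.data['h']], 2)))}

    def forward(self, g_dgl, nfeats, efeats):
        with g_dgl.local_scope():
            g_dgl.ndata['h'] = nfeats
            g_dgl.edata['h'] = efeats
            g_dgl.update_all(self.message_func, fn.sum('m', 'h_neigh'))
            # use the configured activation instead of a hard-coded F.relu here
            g_dgl.ndata['h'] = self.activation(
                self.W_apply(torch.cat([g_dgl.ndata['h'], g_dgl.ndata['h_neigh']], 2)))
            return g_dgl.ndata['h']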