Difference between readout (e.g. dgl.max_nodes) and Pooling layers

What’s the difference between having a readout of dgl.max_nodes and using MaxPooling? The same question applies to other readout functions and pooling layers. Example network below:

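import dgl
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GraphConv, MaxPooling

class GCN(nn.Module):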
    def __init__(self, in_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, 64)
        self.conv2 = GraphConv(64, 128)
        self.conv3 = GraphConv(128, 64)
        self.maxpool = MaxPooling()  # defined but not used: forward() reads out with dgl.max_nodes
        self.dense = nn.Linear(64, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        h = F.relu(h)
        h = self.conv3(g, h)
        h = F.relu(h)
        g.ndata['h'] = h
        hg = dgl.max_nodes(g, 'h')  # per-graph max over node features, one row per graph
        return self.dense(hg)

Thanks!

There’s no difference: MaxPooling is just a wrapper around dgl.max_nodes that makes it an nn.Module.
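To make the point concrete, here is a plain-PyTorch sketch of what a max readout computes on a batched graph. It is illustrative only, not DGL's actual implementation, and the function name max_readout is hypothetical. It assumes each graph's nodes occupy a contiguous block of rows (the layout dgl.batch produces) and that the per-graph node counts are given as a Python list; in DGL they could be obtained with g.batch_num_nodes().tolist().

```python
from typing import List

import torch


def max_readout(h: torch.Tensor, batch_num_nodes: List[int]) -> torch.Tensor:
    """Illustrative per-graph max readout (a sketch, not DGL's implementation).

    h: (total_nodes, feat_dim) features of a batched graph, where each graph's
       nodes occupy a contiguous block of rows (the layout dgl.batch produces).
    batch_num_nodes: node count of each graph; every graph must have >= 1 node.
    Returns a (num_graphs, feat_dim) tensor, one row per graph.
    """
    chunks = torch.split(h, batch_num_nodes, dim=0)  # one block of rows per graph
    # Element-wise max over the nodes of each graph, then stack the per-graph rows.
    return torch.stack([c.max(dim=0).values for c in chunks])


h = torch.tensor([[1., 5.], [3., 2.],              # graph 0: 2 nodes
                  [0., 7.], [4., 1.], [2., 9.]])   # graph 1: 3 nodes
print(max_readout(h, [2, 3]).tolist())  # [[3.0, 5.0], [4.0, 9.0]]
```

Whether you call the function or the module, the result has one row per graph, which is why it can feed straight into nn.Linear.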


Thank you. Do you know how to use the pooling layer in place of the DGL readout function? I keep getting an error.

Instead of

g.ndata['h'] = h
result = dgl.max_nodes(g, 'h')

You can write

pool = MaxPooling()
result = pool(g, h)

Could you share the error message and the stack trace? I suspect that simply swapping max_nodes for the nn module won’t fix the problem on its own.


Below is the final model I’m using, and it seems to work. I switched to SumPooling because that is what I needed.

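import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GraphConv, SumPooling

class GCN(nn.Module):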
    def __init__(self, in_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, 64)
        self.conv2 = GraphConv(64, 128)
        self.conv3 = GraphConv(128, 64)
        self.pool = SumPooling()
        self.dense = nn.Linear(64, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        h = F.relu(h)
        h = self.conv3(g, h)
        h = F.relu(h)
        hg = self.pool(g, h)  # no need to store h in g.ndata first
        hg = torch.tanh(hg)
        return self.dense(hg)

Thanks!
