Stacking multiple EGNN layers

Greetings! I have the following code from DGL for EGNN, which defines only a single layer. I am curious how to stack multiple EGNN layers.

    import torch
    import torch.nn as nn
    import dgl.function as fn

    class EGNNConv(nn.Module):
        def __init__(self, in_size, hidden_size, out_size, edge_feat_size=0):
            super(EGNNConv, self).__init__()

            self.in_size = in_size
            self.hidden_size = hidden_size
            self.out_size = out_size
            self.edge_feat_size = edge_feat_size
            act_fn = nn.SiLU()

            # \phi_e
            self.edge_mlp = nn.Sequential(
                # +1 for the radial feature: ||x_i - x_j||^2
                nn.Linear(in_size * 2 + edge_feat_size + 1, hidden_size),
                act_fn,
                nn.Linear(hidden_size, hidden_size),
                act_fn,
            )

            # \phi_h
            self.node_mlp = nn.Sequential(
                nn.Linear(in_size + hidden_size, hidden_size),
                act_fn,
                nn.Linear(hidden_size, out_size),
            )

            # \phi_x
            self.coord_mlp = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                act_fn,
                nn.Linear(hidden_size, 1, bias=False),
            )

        def message(self, edges):
            """message function for EGNN"""
            # concat features for edge mlp
            if self.edge_feat_size > 0:
                f = torch.cat(
                    [
                        edges.src["h"],
                        edges.dst["h"],
                        edges.data["radial"],
                        edges.data["a"],
                    ],
                    dim=-1,
                )
            else:
                f = torch.cat(
                    [edges.src["h"], edges.dst["h"], edges.data["radial"]], dim=-1
                )

            msg_h = self.edge_mlp(f)
            msg_x = self.coord_mlp(msg_h) * edges.data["x_diff"]

            return {"msg_x": msg_x, "msg_h": msg_h}

        def forward(self, graph, node_feat, coord_feat, edge_feat=None):
            with graph.local_scope():
                # node feature
                graph.ndata["h"] = node_feat
                # coordinate feature
                graph.ndata["x"] = coord_feat
                # edge feature
                if self.edge_feat_size > 0:
                    assert edge_feat is not None, "Edge features must be provided."
                    graph.edata["a"] = edge_feat
                # get coordinate diff & radial features
                graph.apply_edges(fn.u_sub_v("x", "x", "x_diff"))
                graph.edata["radial"] = (
                    graph.edata["x_diff"].square().sum(dim=1).unsqueeze(-1)
                )
                # normalize coordinate difference
                graph.edata["x_diff"] = graph.edata["x_diff"] / (
                    graph.edata["radial"].sqrt() + 1e-30
                )
                graph.apply_edges(self.message)
                graph.update_all(fn.copy_e("msg_x", "m"), fn.mean("m", "x_neigh"))
                graph.update_all(fn.copy_e("msg_h", "m"), fn.sum("m", "h_neigh"))

                h_neigh, x_neigh = graph.ndata["h_neigh"], graph.ndata["x_neigh"]

                h = self.node_mlp(torch.cat([node_feat, h_neigh], dim=-1))
                x = coord_feat + x_neigh
                return h, x

    model = EGNNConv(num_node_feat, num_node_feat, out_dim, num_edge_feat)
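
For reference, this is how I currently call it (a toy setup: `dgl.rand_graph` just gives me a random graph to test on, and the tensor shapes are made up):

    import dgl
    import torch

    g = dgl.rand_graph(10, 30)                  # toy graph: 10 nodes, 30 edges
    node_feat = torch.randn(g.num_nodes(), num_node_feat)
    coord_feat = torch.randn(g.num_nodes(), 3)  # 3D coordinates
    edge_feat = torch.randn(g.num_edges(), num_edge_feat)

    h, x = model(g, node_feat, coord_feat, edge_feat)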

Thanks in advance!

Hi, you can stack EGNN layers just like any other NN layers. One thing to watch is that each layer's `out_size` must match the next layer's `in_size`, since the updated node features are fed straight into the following layer. Here is an example:

    g = ...  # input graph
    h = ...  # node features, shape (num_nodes, size)
    x = ...  # coordinate features, e.g. shape (num_nodes, 3)

    size = h.shape[-1]  # keep the feature size fixed across all four layers
    net = nn.ModuleList([EGNNConv(size, size, size) for _ in range(4)])
    for layer in net:
        h, x = layer(g, h, x)
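
If you want to wrap the stack in a single module, and also pass edge features through every layer, a minimal sketch could look like this (the `EGNN` class name, `n_layers`, and the size handling are my own choices here, not anything built into DGL):

    class EGNN(nn.Module):
        """Stack of EGNNConv layers (illustrative sketch, not a DGL built-in)."""

        def __init__(self, in_size, hidden_size, out_size, n_layers=4, edge_feat_size=0):
            super().__init__()
            self.layers = nn.ModuleList()
            for i in range(n_layers):
                # first layer maps in_size -> hidden_size,
                # last layer maps hidden_size -> out_size
                layer_in = in_size if i == 0 else hidden_size
                layer_out = out_size if i == n_layers - 1 else hidden_size
                self.layers.append(
                    EGNNConv(layer_in, hidden_size, layer_out, edge_feat_size)
                )

        def forward(self, g, h, x, edge_feat=None):
            # each layer consumes the previous layer's updated features
            for layer in self.layers:
                h, x = layer(g, h, x, edge_feat)
            return h, x

With your sizes that would be something like `model = EGNN(num_node_feat, num_node_feat, out_dim, n_layers=4, edge_feat_size=num_edge_feat)`, called as `h, x = model(g, node_feat, coord_feat, edge_feat)`.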

Thanks a lot! That works! Greatly appreciated!
