Implementation of simple Graph Network (https://arxiv.org/abs/1806.01261)

I would like to implement a simple graph network as described in https://arxiv.org/abs/1806.01261, with some simplifying assumptions.

First I want to do an edge embedding by concatenating two nodes’ features joined by an edge. Then I want to create the embedding with a fully connected network on the concatenated vector.

Secondly, I want to sum all the edge embeddings connected to a node and feed them into a second fully connected network to get a node embedding.

This process seems to work when I only have one of those layers (see the code below). However, when I add more than one, I run into problems. The first is that the node embedding dimension differs from the input feature dimension; that seems to work fine if I redefine `Graph.ndata['node_input_embedding'] = Graph.ndata['node_output_embedding']`. However, I then get problems when I have to set the new edge embedding, since it doesn't have the same dimension as the previous one. I am wondering what the most efficient way is to adapt my code so that I can add more layers.

    import dgl
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    from fcn import FullyConnectedModel

    class GraphLayer(nn.Module):
        """Single message-passing layer producing node embeddings from edge embeddings.

        For every edge, the 'node_input_embedding' features of its source and
        destination nodes are concatenated and passed through ``mlp_edges`` to
        get an edge embedding. For every node, the received edge embeddings
        are summed and passed through ``mlp_nodes`` to get the node's output
        embedding, stored under 'node_output_embedding'.
        """

        def __init__(self,
                    input_node_embedding_dim,
                    edge_embedding_dim,
                    output_node_embedding_dim,
                    ):
            """
            Args:
                input_node_embedding_dim: dimension of the input node embeddings
                    (raw features for the first layer, otherwise the output of
                    the previous layer).
                edge_embedding_dim: dimension of the edge embeddings.
                output_node_embedding_dim: dimension of the output node embeddings.
            """
            super().__init__()

            self.input_node_embedding_dim = input_node_embedding_dim
            self.edge_embedding_dim = edge_embedding_dim
            self.output_node_embedding_dim = output_node_embedding_dim

            # Initialize MLPs (one for the edge embedding, another one for the
            # node embedding). The edge MLP takes two concatenated node
            # embeddings, hence twice the input dimension.
            self.mlp_edges = FullyConnectedModel(n_features=int(self.input_node_embedding_dim * 2),
                                                 n_hidden_layers=2,
                                                 n_hidden_dim=128,
                                                 n_output_dim=self.edge_embedding_dim)

            self.mlp_nodes = FullyConnectedModel(n_features=self.edge_embedding_dim,
                                                 n_hidden_layers=2,
                                                 n_hidden_dim=128,
                                                 n_output_dim=self.output_node_embedding_dim)

        def edge_update(self, edges):
            """Compute the edge embedding from the two end-point node embeddings.

            Args:
                edges: batch of edges exposing ``src`` and ``dst`` feature dicts.

            Returns:
                dict with the key 'edge_embedding'.
            """
            edge_features = torch.cat([edges.src['node_input_embedding'],
                                       edges.dst['node_input_embedding']],
                                      dim=1)
            edge_features = self.mlp_edges(edge_features)
            return {'edge_embedding': edge_features}

        def node_update(self, nodes):
            """Sum the incoming edge embeddings and map them to a node embedding.

            Args:
                nodes: batch of nodes exposing a ``mailbox`` of received messages.

            Returns:
                dict with the key 'node_output_embedding'.
            """
            # Reduces over dim 1 of the mailbox tensor (assumed to be the
            # per-node message axis -- confirm against the DGL documentation).
            sum_edge_embeddings = torch.sum(nodes.mailbox["edge_embedding"], dim=1)
            node_embedding = self.mlp_nodes(sum_edge_embeddings)
            return {'node_output_embedding': node_embedding}

        def forward(self, graph):
            """Run one round of message passing on ``graph``.

            Args:
                graph: DGL graph whose nodes carry 'node_input_embedding'.

            Returns:
                The 'node_output_embedding' node data of the graph.
            """
            # Messages are sent along all edges, then received on all nodes.
            graph.send(graph.edges(), self.edge_update)
            graph.recv(graph.nodes(), self.node_update)

            return graph.ndata['node_output_embedding']

    class GraphNetwork(nn.Module):
        """Graph network made of a single ``GraphLayer``."""

        def __init__(self,
                    node_features_dim,
                    edge_embedding_dim,
                    output_node_embedding_dim):
            """
            Args:
                node_features_dim: number of input features per node.
                edge_embedding_dim: dimension of the edge embeddings.
                output_node_embedding_dim: dimension of the output node embeddings.
            """
            super().__init__()

            self.graph_layer_1 = GraphLayer(
                node_features_dim,
                edge_embedding_dim,
                output_node_embedding_dim,
            )

        def forward(self, graph):
            """Return the output of the single graph layer applied to ``graph``."""
            return self.graph_layer_1(graph)


    # -----------------------------------------------------------------------------
    # TEST MODEL
    # -----------------------------------------------------------------------------

    if __name__ == '__main__':
        # Smoke test: build a fully connected 3-node graph with one scalar
        # feature per node and check the shape of the network output.

        n_nodes = 3

        G = dgl.DGLGraph()

        G.add_nodes(n_nodes)
        G.ndata["node_input_embedding"] = torch.zeros((n_nodes, 1))
        G.ndata["node_input_embedding"][0, 0] = 1.0
        G.ndata["node_input_embedding"][1, 0] = 2.0
        G.ndata["node_input_embedding"][2, 0] = 3.0

        # Connect all nodes
        G.add_edges([1, 2], 0)
        G.add_edges([0, 2], 1)
        G.add_edges([0, 1], 2)

        node_features_dim = 1
        edge_embedding_dim = 2
        output_node_embedding_dim = 5

        graphnet = GraphNetwork(node_features_dim,
                                edge_embedding_dim,
                                output_node_embedding_dim)

        # Run the forward pass once and reuse the result for the checks below.
        output = graphnet(G)
        print(output)
        assert output.shape[0] == n_nodes
        assert output.shape[-1] == output_node_embedding_dim

        pytorch_total_params = sum(p.numel() for p in graphnet.parameters() if p.requires_grad)

        print(f'Model parameters: {pytorch_total_params}')

Hi,

You can change

        graph.send(graph.edges(), self.edge_update)
        graph.recv(graph.nodes(), self.node_update)

to

        graph.update_all(self.edge_update, self.node_update)

The reason is that the feature dimension is fixed for each frame (e.g. `g.ndata['embedding']`); you cannot partially change it during message passing. `update_all`, however, replaces the whole frame with a new one (roughly equivalent to your `Graph.ndata['node_input_embedding'] = Graph.ndata['node_output_embedding']`), so it won't be a problem.

That's great! Thank you very much. I'm just wondering: would that not cause problems with backpropagation?

(Also, totally unrelated, do you know what is the right way to insert code in here?)

No. If you write something like `x = th.relu(x)`, backpropagation doesn't break; reassigning the feature is the same situation.
For code, you can wrap it in backticks (`) for inline use. For a block of code, put
```python on the first line and ``` on the last line. This is the same as GitHub's Markdown syntax.

Thanks! This is my final implementation:

import dgl 
import torch
import torch.nn as nn
import torch.nn.functional as F

from fcn import FullyConnectedModel

class GraphLayer(nn.Module):
    """Single message-passing layer producing node embeddings from edge embeddings.

    For every edge, the 'node_input_embedding' features of its source and
    destination nodes are concatenated and passed through ``mlp_edges`` to get
    an edge embedding. For every node, the received edge embeddings are summed
    and passed through ``mlp_nodes`` to get the node's output embedding, stored
    under 'node_output_embedding'.
    """

    def __init__(self,
                input_node_dim,
                edge_dim,
                output_node_dim,
                n_hidden_edges=2,
                n_hidden_nodes=2,
                dim_hidden_edges=128,
                dim_hidden_nodes=128,
                ):
        """
        Args:
            input_node_dim: dimension of the input node embeddings (raw features
                for the first layer, otherwise the output of the previous layer)
            edge_dim: dimension of the edge embeddings
            output_node_dim: dimension of the output node embeddings
            n_hidden_edges: number of hidden layers for the MLP defining the edge embedding
            n_hidden_nodes: number of hidden layers for the MLP defining the node embedding
            dim_hidden_edges: number of neurons per layer in the MLP defining the edge embedding
            dim_hidden_nodes: number of neurons per layer in the MLP defining the node embedding
        """
        super().__init__()

        self.input_node_dim = input_node_dim
        self.edge_dim = edge_dim
        self.output_node_dim = output_node_dim

        # Initialize MLPs (one for the edge embedding, another one for the node
        # embedding). The edge MLP takes two concatenated node embeddings,
        # hence twice the input dimension.
        self.mlp_edges = FullyConnectedModel(n_features=int(self.input_node_dim * 2),
                                             n_hidden_layers=n_hidden_edges,
                                             n_hidden_dim=dim_hidden_edges,
                                             n_output_dim=self.edge_dim)
        self.mlp_nodes = FullyConnectedModel(n_features=self.edge_dim,
                                             n_hidden_layers=n_hidden_nodes,
                                             n_hidden_dim=dim_hidden_nodes,
                                             n_output_dim=self.output_node_dim)

    def edge_update(self, edges):
        """Compute the edge embedding from the two end-point node embeddings.

        Args:
            edges: batch of edges

        Returns:
            edge embedding dictionary (key 'edge_embedding')
        """
        edge_features = torch.cat([edges.src['node_input_embedding'],
                                   edges.dst['node_input_embedding']],
                                  dim=1)
        edge_features = self.mlp_edges(edge_features)
        return {'edge_embedding': edge_features}

    def node_update(self, nodes):
        """Sum the incoming edge embeddings and map them to a node embedding.

        Args:
            nodes: batch of nodes

        Returns:
            node embedding dictionary (key 'node_output_embedding')
        """
        # Reduces over dim 1 of the mailbox tensor (assumed to be the per-node
        # message axis -- confirm against the DGL documentation).
        sum_edge_embeddings = torch.sum(nodes.mailbox["edge_embedding"], dim=1)
        node_embedding = self.mlp_nodes(sum_edge_embeddings)
        return {'node_output_embedding': node_embedding}

    def forward(self, graph):
        """Run one round of message passing on ``graph``.

        Args:
            graph: dgl graph whose nodes carry 'node_input_embedding'

        Returns:
            Node embedding (the graph's 'node_output_embedding' node data)
        """
        graph.update_all(self.edge_update, self.node_update)
        return graph.ndata['node_output_embedding']

class GraphNetwork(nn.Module):
    """Two stacked ``GraphLayer`` modules applied to the same graph."""

    def __init__(self,
                node_features_dim,
                edge_dim,
                output_node_dim,
                n_hidden_edges=2,
                n_hidden_nodes=2,
                dim_hidden_edges=128,
                dim_hidden_nodes=128,
                ):
        '''
        Args:
            node_features_dim: number of input features per node
            edge_dim: size of the embedding of the edges connecting nodes
            output_node_dim: size of the output node embedding
            n_hidden_edges: number of hidden layers for the MLP defining the edge embedding
            n_hidden_nodes: number of hidden layers for the MLP defining the node embedding
            dim_hidden_edges: number of neurons per layer in the MLP defining the edge embedding
            dim_hidden_nodes: number of neurons per layer in the MLP defining the node embedding
        '''

        super().__init__()
        self.graph_layer_1 = GraphLayer(node_features_dim,
                                    edge_dim,
                                    output_node_dim,
                                    n_hidden_edges=n_hidden_edges,
                                    n_hidden_nodes=n_hidden_nodes,
                                    dim_hidden_edges=dim_hidden_edges,
                                    dim_hidden_nodes=dim_hidden_nodes,
                                    )

        # The second layer consumes the first layer's output, so its input
        # dimension is output_node_dim.
        self.graph_layer_2 = GraphLayer(output_node_dim,
                                    edge_dim,
                                    output_node_dim,
                                    n_hidden_edges=n_hidden_edges,
                                    n_hidden_nodes=n_hidden_nodes,
                                    dim_hidden_edges=dim_hidden_edges,
                                    dim_hidden_nodes=dim_hidden_nodes,
                                    )

    def forward(self, graph, features):
        '''
        Args:
            graph: dgl graph
            features: input node features, stored on the graph as
                'node_input_embedding' before the first layer runs
        Returns:
            logits (per node prediction)

        Note:
            The graph is modified: after the call its 'node_input_embedding'
            holds the output of the first layer, not ``features``.
        '''
        graph.ndata['node_input_embedding'] = features
        logits = self.graph_layer_1(graph)
        # Feed the first layer's output to the second layer through the same
        # node-data key that GraphLayer reads from.
        graph.ndata['node_input_embedding'] = logits
        logits = self.graph_layer_2(graph)
        return logits

# -----------------------------------------------------------------------------
# TEST MODEL
# -----------------------------------------------------------------------------

if __name__ == '__main__':
    # Smoke test: build a fully connected 3-node graph with one scalar feature
    # per node and check the shape of the network output.

    n_nodes = 3

    G = dgl.DGLGraph()

    G.add_nodes(n_nodes)

    # One scalar input feature per node; GraphNetwork.forward stores these on
    # the graph as 'node_input_embedding'.
    features = torch.tensor([[1.0], [2.0], [3.0]])

    # Connect all nodes
    G.add_edges([1, 2], 0)
    G.add_edges([0, 2], 1)
    G.add_edges([0, 1], 2)

    node_features_dim = 1
    edge_dim = 2
    output_node_dim = 5

    graphnet = GraphNetwork(node_features_dim,
                            edge_dim,
                            output_node_dim)

    # forward() requires both the graph and the node features.
    output = graphnet(G, features)
    assert output.shape[0] == n_nodes
    assert output.shape[-1] == output_node_dim

    pytorch_total_params = sum(p.numel() for p in graphnet.parameters() if p.requires_grad)

    print(f'Model parameters: {pytorch_total_params}')
                                                               

I’m just not very happy about this line: graph.ndata[‘node_input_embedding’] = logits

Do you have any suggestions for improvement?