I would like to implement a simple graph network as described in https://arxiv.org/abs/1806.01261, with some simplifying assumptions.
First, I want to compute an edge embedding: for each edge, concatenate the features of the two nodes it joins, and then pass the concatenated vector through a fully connected network.
Second, for each node I want to sum the embeddings of all edges connected to it and feed that sum into a second fully connected network to obtain the node embedding.
This process seems to work when I have only one such layer (see the code below). However, when I add more than one layer, I run into problems. The first is that the node embedding dimension differs from the input feature dimension. That part seems to work if I reassign `Graph.ndata['node_input_embedding'] = Graph.ndata['node_output_embedding']`. However, I then run into problems when setting the new `edge_embedding`, since it does not have the same dimension as the previous one. What is the most efficient way to adapt my code so that I can add more layers?
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from fcn import FullyConnectedModel
class GraphLayer(nn.Module):
    """One message-passing layer of a simple graph network.

    Each edge is embedded by an MLP applied to the concatenation of its source
    and destination node features. Each node then sums the embeddings of its
    incoming edges and passes that sum through a second MLP.

    All intermediate features are written inside ``graph.local_scope()``, so
    the layer never leaves data on the graph. This is what makes it possible
    to stack several layers whose node/edge embedding widths differ: no layer
    tries to overwrite a stored tensor with one of a different shape.
    """

    def __init__(self,
                 input_node_embedding_dim,
                 edge_embedding_dim,
                 output_node_embedding_dim,
                 n_hidden_layers=2,
                 n_hidden_dim=128,
                 ):
        """
        Args:
            input_node_embedding_dim: dimension of the input node embeddings
                (the raw node features for the first layer, the previous
                layer's output embeddings otherwise).
            edge_embedding_dim: dimension of the edge embeddings.
            output_node_embedding_dim: dimension of the output node embeddings.
            n_hidden_layers: number of hidden layers in each of the two MLPs.
            n_hidden_dim: width of the hidden layers in each of the two MLPs.
        """
        super().__init__()
        self.input_node_embedding_dim = input_node_embedding_dim
        self.edge_embedding_dim = edge_embedding_dim
        self.output_node_embedding_dim = output_node_embedding_dim
        # Edge MLP sees [src features || dst features], hence twice the input width.
        self.mlp_edges = FullyConnectedModel(n_features=int(self.input_node_embedding_dim * 2),
                                             n_hidden_layers=n_hidden_layers,
                                             n_hidden_dim=n_hidden_dim,
                                             n_output_dim=self.edge_embedding_dim)
        self.mlp_nodes = FullyConnectedModel(n_features=self.edge_embedding_dim,
                                             n_hidden_layers=n_hidden_layers,
                                             n_hidden_dim=n_hidden_dim,
                                             n_output_dim=self.output_node_embedding_dim)

    def edge_update(self, edges):
        """DGL message function: embed each edge from its two endpoint features.

        Returns:
            dict with the message tensor under ``'edge_embedding'``.
        """
        edge_features = torch.cat([edges.src['node_input_embedding'],
                                   edges.dst['node_input_embedding']],
                                  dim=1)
        return {'edge_embedding': self.mlp_edges(edge_features)}

    def node_update(self, nodes):
        """DGL reduce function: sum incoming edge embeddings, then apply the node MLP.

        Returns:
            dict with the new node embedding under ``'node_output_embedding'``.
        """
        # DGL mailboxes are laid out as (nodes, in_degree, ...); dim=1 sums
        # over the incoming edges of each node.
        sum_edge_embeddings = torch.sum(nodes.mailbox['edge_embedding'], dim=1)
        return {'node_output_embedding': self.mlp_nodes(sum_edge_embeddings)}

    def forward(self, graph, node_features=None):
        """Run one round of message passing on ``graph``.

        Args:
            graph: DGL graph. When ``node_features`` is None, the layer reads
                its input from ``graph.ndata['node_input_embedding']``.
            node_features: optional node-feature tensor for this layer, one row
                per node; it is used instead of the graph's stored
                ``'node_input_embedding'`` for this call only. This lets
                stacked layers feed each other without mutating the graph.

        Returns:
            The node embeddings produced by ``mlp_nodes``.

        The graph itself is not modified: every feature written here lives
        only inside ``graph.local_scope()``.
        """
        with graph.local_scope():
            if node_features is not None:
                graph.ndata['node_input_embedding'] = node_features
            # update_all = send on all edges + recv on all nodes in one call,
            # without keeping the messages around afterwards.
            # NOTE(review): nodes with no incoming edges never reach
            # node_update; check how your DGL version fills their output.
            graph.update_all(self.edge_update, self.node_update)
            return graph.ndata['node_output_embedding']


class GraphNetwork(nn.Module):
    """A stack of ``GraphLayer`` blocks applied sequentially to one graph.

    With the default ``n_layers=1`` this is exactly one ``GraphLayer`` mapping
    ``node_features_dim`` to ``output_node_embedding_dim``.
    """

    def __init__(self,
                 node_features_dim,
                 edge_embedding_dim,
                 output_node_embedding_dim,
                 n_layers=1,
                 hidden_node_embedding_dim=None):
        """
        Args:
            node_features_dim: dimension of ``graph.ndata['node_input_embedding']``.
            edge_embedding_dim: dimension of the edge embeddings in every layer.
            output_node_embedding_dim: node embedding dimension of the last layer.
            n_layers: number of stacked ``GraphLayer`` blocks.
            hidden_node_embedding_dim: node embedding dimension between layers;
                defaults to ``output_node_embedding_dim``.

        Raises:
            ValueError: if ``n_layers`` is smaller than 1.
        """
        super().__init__()
        if n_layers < 1:
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")
        if hidden_node_embedding_dim is None:
            hidden_node_embedding_dim = output_node_embedding_dim
        self.n_layers = n_layers
        self._layer_names = []
        in_dim = node_features_dim
        for index in range(1, n_layers + 1):
            out_dim = output_node_embedding_dim if index == n_layers else hidden_node_embedding_dim
            name = f"graph_layer_{index}"
            # Named graph_layer_1, graph_layer_2, ... so the single-layer
            # attribute name and state_dict keys match the original model.
            self.add_module(name, GraphLayer(in_dim, edge_embedding_dim, out_dim))
            self._layer_names.append(name)
            in_dim = out_dim

    def forward(self, graph):
        """Apply every layer in order and return the last layer's node embeddings."""
        # None makes the first layer read graph.ndata['node_input_embedding'];
        # each later layer receives the previous layer's output explicitly.
        node_embeddings = None
        for name in self._layer_names:
            node_embeddings = getattr(self, name)(graph, node_embeddings)
        return node_embeddings
# -----------------------------------------------------------------------------
# TEST MODEL
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    n_nodes = 3
    G = dgl.DGLGraph()
    G.add_nodes(n_nodes)
    # One scalar feature per node: node i gets the value i + 1.
    G.ndata["node_input_embedding"] = torch.arange(
        1, n_nodes + 1, dtype=torch.float32).unsqueeze(1)
    # Connect all nodes: a directed edge between every ordered pair, no self-loops.
    G.add_edges([1, 2], 0)
    G.add_edges([0, 2], 1)
    G.add_edges([0, 1], 2)

    # Read the feature width from the data instead of hard-coding it.
    node_features_dim = G.ndata["node_input_embedding"].shape[-1]
    edge_embedding_dim = 2
    output_node_embedding_dim = 5
    graphnet = GraphNetwork(node_features_dim,
                            edge_embedding_dim,
                            output_node_embedding_dim)

    # Run the forward pass once, so the printed tensor is the one being checked.
    node_embeddings = graphnet(G)
    print(node_embeddings)
    assert node_embeddings.shape[0] == n_nodes
    assert node_embeddings.shape[-1] == output_node_embedding_dim

    pytorch_total_params = sum(p.numel() for p in graphnet.parameters() if p.requires_grad)
    print(f'Model parameters: {pytorch_total_params}')