I have implemented a simple graph network as described in https://arxiv.org/abs/1806.01261. However, my implementation seems to be much slower than a model built with GraphConv when trained on the same dataset with a similar number of model parameters. Any ideas on how to speed it up?
Thanks in advance.
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from fcn import FullyConnectedModel
class GraphLayer(nn.Module):
    """One message-passing layer: edge MLP, sum aggregation, node MLP.

    For every edge, the input embeddings of its source and destination nodes
    are concatenated and passed through ``mlp_edges``. The resulting edge
    embeddings are summed over each node's incoming edges and the sum is
    passed through ``mlp_nodes`` to obtain the node's output embedding.
    """

    def __init__(
        self,
        input_node_dim,
        edge_dim,
        output_node_dim,
        n_hidden_edges=3,
        n_hidden_nodes=3,
        dim_hidden_edges=512,
        dim_hidden_nodes=512,
    ):
        """
        Args:
            input_node_dim: dimension of the input node embeddings (features
                for the first layer, output of the previous layer otherwise)
            edge_dim: dimension of the edge embeddings
            output_node_dim: dimension of the output node embeddings
            n_hidden_edges: number of hidden layers for the MLP defining the edge embedding
            n_hidden_nodes: number of hidden layers for the MLP defining the node embedding
            dim_hidden_edges: number of neurons per layer in the MLP defining the edge embedding
            dim_hidden_nodes: number of neurons per layer in the MLP defining the node embedding
        """
        super().__init__()
        self.input_node_dim = input_node_dim
        self.edge_dim = edge_dim
        self.output_node_dim = output_node_dim
        # Initialize MLPs (one for the edge embedding, another one for the node embedding).
        # The edge MLP consumes source and destination embeddings concatenated,
        # hence twice the node embedding width.
        self.mlp_edges = FullyConnectedModel(
            n_features=int(self.input_node_dim * 2),
            n_hidden_layers=n_hidden_edges,
            n_hidden_dim=dim_hidden_edges,
            n_output_dim=self.edge_dim,
        )
        self.mlp_nodes = FullyConnectedModel(
            n_features=self.edge_dim,
            n_hidden_layers=n_hidden_nodes,
            n_hidden_dim=dim_hidden_nodes,
            n_output_dim=self.output_node_dim,
        )

    def edge_update(self, edges):
        """Message function: embed each edge from its two endpoint embeddings.

        Args:
            edges: batch of edges exposing ``src`` and ``dst`` node data
        Returns:
            edge embedding dictionary with key ``'edge_embedding'``
        """
        edge_features = torch.cat(
            [edges.src['node_input_embedding'], edges.dst['node_input_embedding']],
            dim=1,
        )
        edge_features = self.mlp_edges(edge_features)
        return {'edge_embedding': edge_features}

    def node_update(self, nodes):
        """Reduce function: sum the incoming edge embeddings and embed the node.

        Kept for callers that want a user-defined reducer; ``forward`` no
        longer uses it because built-in reducers are considerably faster.

        Args:
            nodes: batch of nodes whose ``mailbox`` holds the incoming messages
        Returns:
            node embedding dictionary with key ``'node_output_embedding'``
        """
        # dim=1 is the axis along which the mailbox stacks incoming messages.
        sum_edge_embeddings = torch.sum(nodes.mailbox["edge_embedding"], dim=1)
        node_embedding = self.mlp_nodes(sum_edge_embeddings)
        return {'node_output_embedding': node_embedding}

    def forward(self, graph):
        """Run one round of message passing on ``graph``.

        Reads ``graph.ndata['node_input_embedding']`` and writes the result
        to ``graph.ndata['node_output_embedding']``.

        The aggregation uses DGL's built-in ``sum`` reducer instead of the
        Python ``node_update`` reducer: user-defined reducers are executed
        per in-degree bucket, whereas the built-in one runs as a single sparse
        kernel, and the node MLP is then applied once to all nodes.

        NOTE: nodes without incoming edges receive an all-zero sum, so their
        output is ``mlp_nodes`` applied to zeros (a user-defined reducer is
        not invoked for such nodes).

        Args:
            graph: dgl graph
        Returns:
            Node embedding
        """
        scratch_field = '_sum_edge_embedding'
        graph.update_all(
            message_func=self.edge_update,
            reduce_func=dgl.function.sum('edge_embedding', scratch_field),
        )
        # Remove the temporary aggregate so the graph only keeps the final embedding.
        sum_edge_embeddings = graph.ndata.pop(scratch_field)
        node_embedding = self.mlp_nodes(sum_edge_embeddings)
        graph.ndata['node_output_embedding'] = node_embedding
        return node_embedding
class GraphNetwork(nn.Module):
    """Two stacked ``GraphLayer`` blocks producing one logit vector per node."""

    def __init__(
        self,
        node_features_dim,
        edge_dim,
        output_node_dim,
        n_classes,
        n_hidden_edges=3,
        n_hidden_nodes=3,
        dim_hidden_edges=128,
        dim_hidden_nodes=128,
    ):
        """
        Args:
            node_features_dim: number of input features per node
            edge_dim: size of the embedding of the edges connecting nodes
            output_node_dim: size of the node embedding produced by the first layer
            n_classes: size of the node embedding produced by the second layer
            n_hidden_edges: number of hidden layers for the MLP defining the edge embedding
            n_hidden_nodes: number of hidden layers for the MLP defining the node embedding
            dim_hidden_edges: number of neurons per layer in the MLP defining the edge embedding
            dim_hidden_nodes: number of neurons per layer in the MLP defining the node embedding
        """
        super().__init__()
        # Both layers are built with the same MLP depth and width settings.
        mlp_settings = {
            'n_hidden_edges': n_hidden_edges,
            'n_hidden_nodes': n_hidden_nodes,
            'dim_hidden_edges': dim_hidden_edges,
            'dim_hidden_nodes': dim_hidden_nodes,
        }
        self.graph_layer_1 = GraphLayer(
            node_features_dim, edge_dim, output_node_dim, **mlp_settings
        )
        self.graph_layer_2 = GraphLayer(
            output_node_dim, edge_dim, n_classes, **mlp_settings
        )

    def forward(self, graph, features):
        """
        Args:
            graph: dgl graph
            features: input node features, stored on the graph as
                ``'node_input_embedding'`` before the first layer runs
        Returns:
            logits (per node prediction)
        """
        embedding = features
        # Each layer reads its input from the graph, so the previous output is
        # written back under the input key before the next layer is called.
        for layer in (self.graph_layer_1, self.graph_layer_2):
            graph.ndata['node_input_embedding'] = embedding
            embedding = layer(graph)
        return embedding