Improve efficiency of implementation

I have implemented a simple graph network as described in https://arxiv.org/abs/1806.01261. However, my implementation seems to be much slower than a model built using GraphConv when training on the same dataset and with a similar number of model parameters. Any idea on how to speed it up?

Thanks in advance :slight_smile:

 import dgl                                                                                                         
 import torch                                                                                                       
 import torch.nn as nn                                                                                              
 import torch.nn.functional as F                                                                                    
                                                                                                                    
 from fcn import FullyConnectedModel                                                                                
                                                                                                                    
 class GraphLayer(nn.Module):                                                                                       
     def __init__(self,                                                                                             
                 input_node_dim,                                                                                    
                 edge_dim,                                                                                          
                 output_node_dim,                                                                                   
                 n_hidden_edges=3,                                                                                  
                 n_hidden_nodes=3,                                                                                  
                 dim_hidden_edges=512,                                                                              
                 dim_hidden_nodes=512,                                                                              
                 ):                                                                                                 
         """                                                                                                        
         Args:                                                                                                      
             input_node_dim: dimension of the input node emebeddings (featues for first layer, other output previou
 s)                                                                                                                 
             edge_dim: dimension of the edge embeddings.                                                            
             output_node_dim: dimension of the output node emebeddings                                              
             n_hidden_edges: number of hidden layers for the MLP defining the edge embedding                        
             n_hidden_nodes: number of hidden layers for the MLP defining the node embedding                        
             dim_hidden_edges: number of neurons per layer in the MLP defining the edge embedding                   
             dim_hidden_nodes: number of neurons per layer in the MLP defining the node embedding                   
                                                                                                                    
         """                                                                                                        
         super().__init__()                                                                                         
                                                                                                                    
         self.input_node_dim = input_node_dim                                                                       
         self.edge_dim = edge_dim                                                                                   
         self.output_node_dim = output_node_dim                                                                     
         # Initialize MLPs (One for edge embedding, another one for node embedding)                                 
         self.mlp_edges = FullyConnectedModel(n_features = int(self.input_node_dim*2),                              
                                         n_hidden_layers=n_hidden_edges,                                            
                                         n_hidden_dim=dim_hidden_edges,                                             
                                         n_output_dim=self.edge_dim)                                                
         self.mlp_nodes = FullyConnectedModel(n_features = self.edge_dim,                                           
                                         n_hidden_layers=n_hidden_nodes,                                            
                                         n_hidden_dim=dim_hidden_nodes,                                             
                                         n_output_dim=self.output_node_dim)                                         
     def edge_update(self, edges):                                                                                  
         """                                                                                                        
         Args:                                                                                                      
             edges: batch of edges                                                                                  
                                                                                                                    
         Returns:                                                                                                   
             edge embedding dictionary                                                                              
         """                                                                                                        
         edge_features = torch.cat([edges.src['node_input_embedding'], edges.dst['node_input_embedding']],          
                               dim=1)                                                                               
         edge_features = self.mlp_edges(edge_features)                                                              
         return {'edge_embedding': edge_features}
def node_update(self, nodes):                                                                                  
         """                                                                                                        
         Args:                                                                                                      
             nodes: batch of nodes                                                                                  
                                                                                                                    
         Returns:                                                                                                   
             node embedding dictionary                                                                              
         """                                                                                                        
         sum_edge_embeddings = torch.sum(nodes.mailbox["edge_embedding"], dim=1)                                    
         node_embedding = self.mlp_nodes(sum_edge_embeddings)                                                       
         return {'node_output_embedding': node_embedding}                                                           
                                                                                                                    
     def forward(self, graph):                                                                                      
         """                                                                                                        
         Args:                                                                                                      
             graph: dgl grpah                                                                                       
                                                                                                                    
         Returns:                                                                                                   
             Node embedding                                                                                         
         """                                                                                                        
         graph.update_all(message_func=self.edge_update,                                                            
                         reduce_func=self.node_update)                                                              
         return graph.ndata['node_output_embedding']
class GraphNetwork(nn.Module):                                                                                     
     def __init__(self,                                                                                             
                 node_features_dim,                                                                                 
                 edge_dim,                                                                                          
                 output_node_dim,                                                                                   
                 n_classes,                                                                                         
                 n_hidden_edges=3,                                                                                  
                 n_hidden_nodes=3,                                                                                  
                 dim_hidden_edges=128,                                                                              
                 dim_hidden_nodes=128,                                                                              
                 ):                                                                                                 
                                                                                                                    
         '''                                                                                                        
         Args:                                                                                                      
             node_features_dim: number of input features per node                                                   
             edge_dim: size of the embedding of the edges connecting nodes                                          
             output_node_dim: size of the output node embedding                                                     
             n_hidden_edges: number of hidden layers for the MLP defining the edge embedding                        
             n_hidden_nodes: number of hidden layers for the MLP defining the node embedding                        
             dim_hidden_edges: number of neurons per layer in the MLP defining the edge embedding                   
             dim_hidden_nodes: number of neurons per layer in the MLP defining the node embedding                   
         '''                                                                                                        
                                                                                                                    
         super().__init__()                                                                                         
         self.graph_layer_1 = GraphLayer(node_features_dim,                                                         
                                     edge_dim,                                                                      
                                     output_node_dim,                                                               
                                     n_hidden_edges=n_hidden_edges,                                                 
                                     n_hidden_nodes=n_hidden_nodes,                                                 
                                     dim_hidden_edges=dim_hidden_edges,                                             
                                     dim_hidden_nodes=dim_hidden_nodes,                                             
                                     )                                                                              
                                                                                                                    
         self.graph_layer_2 = GraphLayer(output_node_dim,                                                           
                                     edge_dim,                                                                      
                                     n_classes,                                                                     
                                     n_hidden_edges=n_hidden_edges,                                                 
                                     n_hidden_nodes=n_hidden_nodes,                                                 
                                     dim_hidden_edges=dim_hidden_edges,                                             
                                     dim_hidden_nodes=dim_hidden_nodes,                                             
                                     )

def forward(self, graph, features):                                                                            
         '''                                                                                                        
         Args:                                                                                                      
             graph: dgl graph                                                                                       
         Returns:                                                                                                   
             logits (per node prediction)                                                                           
         '''                                                                                                        
         graph.ndata['node_input_embedding'] = features                                                             
         logits = self.graph_layer_1(graph)                                                                         
         graph.ndata['node_input_embedding'] = logits                                                               
         logits = self.graph_layer_2(graph)                                                                         
         return logits

Hi, please try decomposing the node_update function to a built-in function dgl.function.sum and a apply node function:

import dgl.function as fn

     def forward(self, graph):                                                                                      
         """                                                                                                        
         Args:                                                                                                      
             graph: dgl grpah                                                                                       
                                                                                                                    
         Returns:                                                                                                   
             Node embedding                                                                                         
         """                                                                                                        
         graph.update_all(message_func=self.edge_update,                                                            
                         reduce_func=fn.sum('edge_embedding', 'out'),
                         apply_node_func=lambda nodes: {'node_output_embedding': self.mlp_nodes(nodes.data['out'])})                                                              
         return graph.ndata['node_output_embedding']

or

import dgl.function as fn

     def forward(self, graph):                                                                                      
         """                                                                                                        
         Args:                                                                                                      
             graph: dgl grpah                                                                                       
                                                                                                                    
         Returns:                                                                                                   
             Node embedding                                                                                         
         """                                                                                                        
         graph.update_all(message_func=self.edge_update,                                                            
                         reduce_func=fn.sum('edge_embedding', 'out'))                                                              
         return self.mlp_nodes(graph.ndata['out'])

they have the same efficiency.

BTW, we have implemented Recurrent Relational Network in our examples: https://github.com/dmlc/dgl/tree/master/examples/pytorch/rrn, which is similar to your use case.

1 Like

Thanks, that made it much faster :smiley:

1 Like