Loss Backward Taking 30x More Time than Forward Pass

I am building a graph-neural-network-based NER model, using a GNN-BERT-CRF architecture. During training, the forward pass takes a few milliseconds, but the backward pass takes 30 seconds on average.

class GATBertCRF(torch.nn.Module):
    """GAT + token-embedding + CRF tagger over a graph of text nodes.

    Each node of the input graph carries a sequence of token embeddings in
    ``g.ndata['embedding']``. The GAT produces one vector per node. That
    vector is broadcast to every token of its node and concatenated with the
    token embedding. ``self.mlp`` then projects the result to per-tag
    emission scores. A CRF scores the gold tags (``neg_log_likelihood``) or
    decodes the best tags (``forward``).

    NOTE(review): only ``self.crf`` is moved to ``device`` in ``__init__``.
    ``self.gat`` and ``self.mlp`` stay wherever the caller places the module,
    which is why the emissions are moved with ``.to(device)`` afterwards. If
    the module itself is never moved to ``device``, the GAT and the large
    ``mlp`` matmul run elsewhere, presumably on the CPU. Their backward ops
    (plausibly the ``AddmmBackward0`` / ``TBackward0`` seen in profiling)
    would then run there too. Moving the whole model (``model.to(device)``)
    and the graph ``g`` to the same device is likely the fix for a slow
    backward pass. TODO confirm against the training loop.
    """

    def __init__(self, num_classes, input_dim, hidden_dim, output_dim, edge_input_dim, edge_output_dim, multihead_num = 1, dropout = 0.6):
        super(GATBertCRF, self).__init__()

        self.gat = GATLayer(input_dim, hidden_dim, output_dim, edge_input_dim, edge_output_dim, multihead_num, dropout)
        # Input width = GAT node vector (output_dim) + token embedding width.
        self.mlp = torch.nn.Linear(output_dim+EMBEDDING_SHAPE, num_classes)
        # Only the CRF is moved to `device` here (see class NOTE).
        self.crf = CRF(num_classes).to(device)
        self.num_classes = num_classes
        self.input_dim = input_dim

    def _emissions(self, g):
        """Compute CRF emission scores for graph ``g``, token axis first.

        The result is laid out as (seq_len, num_nodes, num_classes). This
        assumes ``g.ndata['embedding']`` is (num_nodes, seq_len, emb_dim) and
        that the GAT returns (num_nodes, output_dim), as required by the
        ``cat`` along dim 2 below. TODO confirm.
        """
        node_repr = self.gat(g)
        embeddings = g.ndata['embedding']
        # Broadcast each node's vector across its tokens. `expand` returns a
        # view instead of `repeat`'s full copy. `torch.cat` materialises the
        # result once either way, and the gradient (a sum over the token axis)
        # is the same.
        node_repr = node_repr.unsqueeze(1).expand(-1, embeddings.shape[1], -1)
        emissions = self.mlp(torch.cat([embeddings, node_repr], dim=2))
        # Put the token axis first, presumably because the CRF expects
        # (seq_len, batch, num_tags), i.e. batch_first=False. The nodes act
        # as the batch.
        return emissions.permute(1, 0, 2)

    def neg_log_likelihood(self, g):
        """Return the negative CRF log-likelihood of the gold tags in ``g``.

        Reads ``g.ndata['tags']`` and ``g.ndata['masks']``. Both are
        transposed so their first axis matches the emissions' token axis.
        """
        emissions = self._emissions(g).to(device)
        labels = g.ndata['tags'].T.to(device)
        mask = g.ndata['masks'].T.to(device)
        return -self.crf(emissions, labels, mask)

    def forward(self, g, *, use_mask=False):
        """Decode the best tag sequence for every node of ``g`` with the CRF.

        Args:
            g: the input graph. ``g.ndata['embedding']`` is always read.
                ``g.ndata['masks']`` is read only when ``use_mask`` is True.
            use_mask: when True, pass the transposed ``g.ndata['masks']`` to
                ``self.crf.decode`` so that padded positions are excluded
                from decoding, as they are from the training loss. The
                default of False keeps the original behaviour, where every
                position up to the padded length is decoded.

        Returns:
            Whatever ``self.crf.decode`` returns. Presumably this is one list
            of tag indices per node.
        """
        # Move to the CRF's device, exactly as neg_log_likelihood does.
        emissions = self._emissions(g).to(device)
        if not use_mask:
            return self.crf.decode(emissions)
        mask = g.ndata['masks'].T.to(device)
        return self.crf.decode(emissions, mask)

This is the model. Can you tell me why `loss.backward()` takes so long?

Hi @debanjanm ,

I think diagnosing this requires domain-specific knowledge. You can use the PyTorch profiler (`torch.profiler`) to see which module takes a long time.

Well, I did. `evaluate_function: AddmmBackward0` and `evaluate_function: TBackward0` are taking a huge amount of time.

This topic was automatically closed 30 days after the last reply. New replies are no longer allowed.