CUDA out of memory error after a few epochs of training

Hi, I’m getting memory errors when training a GCN model on a GPU: the model runs fine for about 25 epochs and then crashes. I think the problem might be related to how I handle the batches, or to the training loop.

```python
optimizer = torch.optim.Adam(model.parameters(),
                             lr=config_dict['learning_rate'])
criterion = nn.MSELoss()
print("Start training...")
start = time.time()
for epoch in range(FLAGS.n_epochs):
    loss_list = []
    # TRAIN
    model.train()
    for batch, data in enumerate(train_dataloader):
        graph, labels = data
        graph.ndata['features'] = graph.ndata['features'].to(device)
        labels = labels.to(device)
        logits = model.forward(graph, graph.ndata['features'])
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_list.append(float(loss.item()))
        torch.cuda.empty_cache()
    loss_data = np.array(loss_list).mean()
    experiment.log_metric("mse", loss_data, step=epoch)
    # VALIDATE
    val_loss_list = []
    model.eval()
    for val_batch, val_data in enumerate(val_dataloader):
        val_graph, val_labels = val_data
        val_graph.ndata['features'] = val_graph.ndata['features'].to(device)
        val_labels = val_labels.to(device)
        logits = model.forward(val_graph, val_graph.ndata['features'])
        val_loss = criterion(logits, val_labels)
        val_loss_list.append(float(val_loss.item()))
        torch.cuda.empty_cache()
    val_loss_data = np.array(val_loss_list).mean()
    print("Epoch {:d} | Train Loss: {:.4f} | Validation loss: {:.4f}".
          format(epoch, loss_data, val_loss_data))
```

Hi, please use `torch.no_grad()` in the evaluation phase.
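A minimal sketch of the validation step with that change applied. It assumes a plain `nn.Linear` model and a `TensorDataset` as hypothetical stand-ins for the GCN and the graph dataloader, and the `evaluate` helper is likewise illustrative; with DGL, the forward call inside the loop would instead be `model(val_graph, val_graph.ndata['features'])` as in the original code.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, dataloader, criterion, device):
    """Hypothetical helper: mean validation loss, computed without autograd."""
    model.eval()  # switch layers such as dropout/batchnorm to inference behaviour
    losses = []
    with torch.no_grad():  # autograd records no graph for these forward passes
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs)
            losses.append(criterion(logits, labels).item())
    return sum(losses) / len(losses)


# Hypothetical stand-ins for the GCN model and the graph dataloader.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(8, 1).to(device)
val_data = TensorDataset(torch.randn(32, 8), torch.randn(32, 1))
val_dataloader = DataLoader(val_data, batch_size=8)

val_loss = evaluate(model, val_dataloader, nn.MSELoss(), device)
print(f"Validation loss: {val_loss:.4f}")
```

Grad mode is restored automatically when the `with` block exits, so the next training epoch is unaffected.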

That solved it, thank you! May I ask why it is necessary, and why `model.eval()` is not enough?

DGL uses `torch.autograd.Function`, and there is a reference-cycle issue with it (https://github.com/pytorch/pytorch/issues/25340) that can cause memory leaks and has not been fixed yet.
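To illustrate when autograd records a custom `Function` at all, here is a small sketch with a toy `Square` function (a hypothetical example, not a DGL operator, and not a reproduction of the leak itself). With grad mode on, the output gets a backward node that holds the state saved on `ctx`; under `torch.no_grad()` the forward still runs but no backward node is attached to the output.

```python
import torch


class Square(torch.autograd.Function):
    """Toy custom op: y = x * x, with a hand-written backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # state kept for the backward pass
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output


x = torch.randn(4, requires_grad=True)

# Grad mode on: the output carries a backward node for Square.
y = Square.apply(x)
print(type(y.grad_fn).__name__, y.requires_grad)

# Grad mode off: same values, but no backward node is recorded.
with torch.no_grad():
    z = Square.apply(x)
print(z.grad_fn, z.requires_grad)  # None False
```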

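You will run into this problem whenever you call a module's forward function without a matching backward call, which is exactly what the validation loop does. We are still trying to fix it; in the meantime, running such forward-only passes under `torch.no_grad()` is a workaround, since no autograd graph is recorded there.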
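As for why `model.eval()` alone is not enough: the two calls do different things. `model.eval()` only switches modules such as dropout and batch norm to inference behaviour, while `torch.no_grad()` is what stops autograd from recording the forward pass. A small sketch with a toy `nn.Sequential` model (a hypothetical stand-in, not the GCN from the question):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 4), nn.Dropout(p=0.5), nn.Linear(4, 1))
x = torch.randn(2, 8)

# eval() disables dropout, but the parameters still require grad,
# so autograd keeps recording a graph for this output.
model.eval()
out_eval = model(x)
print(out_eval.requires_grad, out_eval.grad_fn is not None)  # True True

# no_grad() turns off graph recording for everything inside the block.
with torch.no_grad():
    out_nograd = model(x)
print(out_nograd.requires_grad, out_nograd.grad_fn is None)  # False True
```

Both outputs have the same values, but only the first keeps an autograd graph alive while it is referenced, which is why a validation loop typically uses both calls.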