Hi, I’m running into memory errors when training a GCN model on a GPU: training runs fine for about 25 epochs and then crashes. I think the problem might be related to how I handle the batches, or to something in the training loop. Here is the relevant code:
import time

import numpy as np
import torch
import torch.nn as nn

# model, device, the dataloaders, FLAGS, config_dict and the experiment
# logger are all set up earlier in the script.
optimizer = torch.optim.Adam(model.parameters(),
                             lr=config_dict['learning_rate'])
criterion = nn.MSELoss()

print("Start training...")
start = time.time()

for epoch in range(FLAGS.n_epochs):
    loss_list = []

    # TRAIN
    model.train()
    for batch, data in enumerate(train_dataloader):
        graph, labels = data
        # move node features and labels onto the GPU
        graph.ndata['features'] = graph.ndata['features'].to(device)
        labels = labels.to(device)

        logits = model(graph, graph.ndata['features'])
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # keep only the Python float, not the loss tensor
        loss_list.append(loss.item())

    torch.cuda.empty_cache()
    loss_data = np.array(loss_list).mean()
    experiment.log_metric("mse", loss_data, step=epoch)

    # VALIDATE
    val_loss_list = []
    model.eval()
    for val_batch, val_data in enumerate(val_dataloader):
        val_graph, val_labels = val_data
        val_graph.ndata['features'] = val_graph.ndata['features'].to(device)
        val_labels = val_labels.to(device)

        logits = model(val_graph, val_graph.ndata['features'])
        val_loss = criterion(logits, val_labels)
        val_loss_list.append(val_loss.item())

    torch.cuda.empty_cache()
    val_loss_data = np.array(val_loss_list).mean()

    print("Epoch {:d} | Train Loss: {:.4f} | Validation loss: {:.4f}"
          .format(epoch, loss_data, val_loss_data))
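One thing I’m not sure about: I never wrap the validation pass in torch.no_grad(), so as far as I understand PyTorch still builds autograd state for every validation forward pass. Would rewriting the validation loop like this sketch (same names as above, otherwise unchanged) be the right fix, or at least reduce GPU memory pressure?

    # Validation without gradient tracking: no autograd buffers are kept
    # for the forward pass, which should lower peak GPU memory use.
    model.eval()
    val_loss_list = []
    with torch.no_grad():
        for val_graph, val_labels in val_dataloader:
            val_graph.ndata['features'] = val_graph.ndata['features'].to(device)
            val_labels = val_labels.to(device)
            logits = model(val_graph, val_graph.ndata['features'])
            val_loss_list.append(criterion(logits, val_labels).item())
    val_loss_data = np.array(val_loss_list).mean()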
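I also started to wonder about the line graph.ndata['features'] = graph.ndata['features'].to(device). If the dataloader hands back the same graph objects every epoch, am I writing a GPU copy of the features into each graph held by the dataset, so that GPU memory keeps growing until it runs out? If so, I guess moving the whole batched graph instead would avoid mutating the dataset’s copy (this assumes a DGL graph, where .to() moves the structure together with all node/edge features):

    # Move the whole batched graph (including graph.ndata['features'])
    # to the GPU instead of writing a GPU tensor back into ndata.
    graph = graph.to(device)
    labels = labels.to(device)
    logits = model(graph, graph.ndata['features'])

Does that sound plausible, or should I be looking somewhere else?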
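To narrow it down, my plan was to log the allocated CUDA memory at the end of every epoch and check whether it climbs steadily (a leak) or jumps at some point, along these lines:

    # Print CUDA memory statistics once per epoch to see how
    # allocations evolve over time.
    alloc_mb = torch.cuda.memory_allocated(device) / 1024 ** 2
    peak_mb = torch.cuda.max_memory_allocated(device) / 1024 ** 2
    print("Epoch {:d} | allocated: {:.1f} MiB | peak: {:.1f} MiB".format(
        epoch, alloc_mb, peak_mb))

Any pointers on what else to check would be much appreciated.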