I have ~2M graphs, each with 3000–4000 nodes and ~100,000 edges. If I double the batch size and use 2 GPUs, training fails with a CUDA out-of-memory error. Here is my code:
import os
import h5py
import time
import logging
import numpy as np
from random import randint
from typing import Optional
from operator import itemgetter
from datetime import datetime
# torch imports
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data.distributed import DistributedSampler
from torch_cluster import knn_graph
# deep-graph-library imports
import dgl
from dgl.data import DGLDataset, split_dataset
from dgl.dataloading import GraphDataLoader
# framework imports
from utils import GnnDataset
class DyGCNN(torch.nn.Module):
    """Four-stage EdgeConv graph classifier with a mean-pooled readout.

    Channel width doubles at every EdgeConv stage
    (h*2 -> h*4 -> h*8 -> h*16) before a final linear classifier.
    """

    def __init__(self, in_feats=8, h_feats=8, num_classes=2):
        super(DyGCNN, self).__init__()
        edge_conv = dgl.nn.pytorch.conv.EdgeConv
        self.conv1 = edge_conv(in_feats, h_feats * 2)
        self.conv2 = edge_conv(h_feats * 2, h_feats * 4)
        self.conv3 = edge_conv(h_feats * 4, h_feats * 8)
        self.conv4 = edge_conv(h_feats * 8, h_feats * 16)
        self.clf = torch.nn.Linear(h_feats * 16, num_classes)
        self.actv = torch.nn.Tanh()

    def forward(self, g, in_feat):
        """Return per-graph class logits for the (batched) graph `g`."""
        hidden = in_feat
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            hidden = self.actv(conv(g, hidden))
        # Mean-pool node embeddings into one vector per graph, then classify.
        g.ndata['h'] = hidden
        pooled = dgl.mean_nodes(g, 'h')
        return self.clf(pooled)
def save_state(iteration, network, optimz, vloss, name):
    """Write a training checkpoint (epoch, model, optimizer, loss) to `name`."""
    checkpoint = {
        'epoch': iteration,
        'model_state_dict': network.state_dict(),
        'optimizer_state_dict': optimz.state_dict(),
        'loss': vloss,
    }
    torch.save(checkpoint, name)
def init_process_group(world_size, rank):
    """Join the NCCL process group used for multi-GPU DDP training.

    NOTE(review): this explicit tcp:// init_method carries its own host/port
    and therefore supersedes the MASTER_ADDR/MASTER_PORT env vars set in
    __main__ — only one rendezvous mechanism is needed; confirm which is
    intended.
    """
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://127.0.0.1:12345',
        world_size=world_size,
        rank=rank)
def init_model(seed, device):
    """Build a DyGCNN on `device` and wrap it for distributed training.

    Seeding before construction makes the randomly-initialized weights
    identical across all DDP worker processes.
    """
    torch.manual_seed(seed)
    net = DyGCNN().to(device)
    return DDP(net, device_ids=[device], output_device=device)
class CONFIG:
    # Empty namespace object; hyperparameters are attached as attributes below.
    pass


# Global run configuration, read by get_dataloaders() and main().
config = CONFIG()
config.h5path = 'input.h5'        # HDF5 file holding the graph data
config.geo_path = 'geometry.npz'  # geometry arrays consumed by GnnDataset
config.batch_size_test = 64
config.batch_size_train = 64      # per-process batch size (DDP shards the data)
config.batch_size_val = 64
config.lr = 0.005                 # Adam learning rate
config.step_size = 1              # NOTE(review): unused in this file — scheduler leftover?
config.gamma = 0.1                # NOTE(review): unused in this file
config.gpu = [6]                  # CUDA device IDs; one worker process is spawned per entry
config.epochs = 10
config.k_nn = 30                  # presumably k for k-NN graph construction — verify in GnnDataset
config.report = 50                # checkpoint/log every N batches (rank 0 only)
config.dump_path = 'outputs'      # NOTE(review): unused in this file
# Created at import time; NOTE(review): each mp.spawn worker re-imports this
# module and so re-creates the dataset — confirm GnnDataset is cheap/lazy.
dataset = GnnDataset(config.h5path, config.geo_path, config.k_nn)
def cleanup():
    """Tear down the distributed process group at the end of training."""
    dist.destroy_process_group()
def get_dataloaders(dataset, seed, batch_size=config.batch_size_train):
    """Split `dataset` 70:10:20 and return the DDP-aware training loader.

    Args:
        dataset: the full GnnDataset to split.
        seed: random state for a reproducible split across workers.
        batch_size: per-process batch size (default bound at import time
            from config.batch_size_train).

    Returns:
        A GraphDataLoader over the training split; with use_ddp=True each
        DDP worker sees a distinct shard of the data.
    """
    # Use a 70:10:20 train-val-test split.
    train_set, val_set, test_set = split_dataset(
        dataset,
        frac_list=[0.7, 0.1, 0.2],
        shuffle=True,
        random_state=seed)
    # FIX: the original also constructed val/test GraphDataLoaders here and
    # immediately discarded them (only train_loader was ever returned).
    # Build loaders for val_set/test_set at the call site when needed.
    train_loader = GraphDataLoader(
        train_set,
        use_ddp=True,          # shard batches across DDP workers
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=True,       # speeds up host-to-GPU copies
        drop_last=True)        # keep per-rank batch counts equal
    return train_loader
def main(rank, world_size, seed=0):
    """Per-process DDP training worker; one process per entry in config.gpu.

    Args:
        rank: process index assigned by mp.spawn; indexes into config.gpu.
        world_size: total number of spawned processes (== len(config.gpu)).
        seed: RNG seed shared by all workers so initial weights match.
    """
    init_process_group(world_size, rank)
    # Assume the GPU ID to be the same as the process ID
    device = torch.device('cuda:{:d}'.format(config.gpu[rank]))
    print("Running main worker function on device: {}".format(device))
    torch.cuda.set_device(device)
    model = init_model(seed, device)
    # define the optimizer and the loss function
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    criterion = torch.nn.CrossEntropyLoss()  # consumes raw logits from the model
    # NOTE(review): log-softmax is used only for the accuracy computation
    # below; argmax over the raw logits would give the same predictions.
    softmax = torch.nn.LogSoftmax(dim=1)
    trainloader = get_dataloaders(dataset, seed)
    total_step = len(trainloader)
    filename = 'model.pt'
    if rank == 0:
        print('Number of training iterations per epoch:', total_step)
    for epoch in range(config.epochs):
        model.train()
        # The line below ensures all processes use a different
        # random ordering in data loading for each epoch.
        trainloader.set_epoch(epoch)
        if rank == 0:
            print('Epoch', int(epoch+1), 'Starting @',
                  time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        # ---------------------------------------------------------------------------------
        # training data-set
        # ---------------------------------------------------------------------------------
        for i_batch, (graph, labels) in enumerate(trainloader):
            # Both the batched graph and its labels move to this worker's GPU.
            # NOTE(review): with ~3-4k nodes and ~100k edges per graph, a
            # batch of 64 graphs is large; batch size is per-process, so
            # doubling it doubles each GPU's memory demand — likely the OOM.
            graph = graph.to(device)
            labels = labels.to(device)
            # pop() detaches node features from the graph's ndata store.
            feats = graph.ndata.pop('x')
            linear_model_out = model(graph, feats)
            loss = criterion(linear_model_out, labels)
            softmax_out = softmax(linear_model_out)
            # exp(log_softmax) == softmax; index [1] of torch.max is argmax.
            accuracy = (torch.max(torch.exp(softmax_out), 1)[
                1] == labels).int().sum().item()/len(labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if rank == 0:
                if i_batch % config.report == 0:
                    # NOTE(review): `model` is DDP-wrapped, so the saved
                    # state_dict keys carry a 'module.' prefix; loading into
                    # a bare DyGCNN would need model.module — confirm intent.
                    save_state(epoch, model, optimizer, loss.item(), filename)
                    print('Epoch [{}/{}], Iteration [{}/{}], Loss: {:.4f}, Accuracy: {:.4f}'.format(
                        epoch+1, config.epochs, i_batch+1, total_step, loss.item(), accuracy))
    cleanup()
if __name__ == "__main__":
    start = datetime.now()
    print('Start training at: ', start)
    # ****************************************************************************************************
    # NOTE(review): these env vars appear to be ignored, because
    # init_process_group passes an explicit tcp:// init_method with its own
    # host/port — confirm which rendezvous mechanism is intended.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '6006'
    # One worker process per listed GPU; config.gpu currently holds a single
    # device ID, so only one process is spawned.
    num_gpus = len(config.gpu)
    # mp.spawn prepends the rank: each worker runs main(rank, num_gpus).
    mp.spawn(main, args=(num_gpus,), nprocs=num_gpus, join=True)
    # ****************************************************************************************************
    print("Training completed in: " + str(datetime.now() - start))