GATv2Conv OOM, Help Needed

I am trying to train a GATv2Conv model on my GTX 1070 (8 GB), but I get a CUDA out-of-memory (OOM) error every time I run it.
I have tried a lot of memory optimisations, but I have now run out of ideas, so I am asking for help.
I don’t understand why training takes so much memory.
Can someone help me make this model work on my GPU?

System info :
torch 2.2.1+cuda12.1
dgl 2.1.0+cuda12.1
ubuntu 22.04
Nvidia Driver 550.14 with cuda 12.4 on my machine

I run the script like this:

PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python

I have also tried many memory tricks, such as freeing the CUDA cache as often as I can,
and keeping the whole dataset on the CPU, calling .to("cuda") only when a batch is needed during training.
These tricks help a little — I can get through a few epochs — but training still crashes.

import os
import torch 
import torch.nn as nn
import torch.nn.functional as F
#import torch.optim.lr_scheduler as lr_scheduler
from import DGLDataset
import af_reader_py
import statistics
#from torch.cuda.amp import autocast, GradScaler
import dgl
from sklearn.preprocessing import StandardScaler
from dgl.nn import GATv2Conv

# Root folder of the AF dataset: per-year AF files and cached feature tensors.
af_data_root = "../af_dataset/"
# Root folder of the solver results; label dirs are result_<task>_<year>.
result_root = "../af_dataset/all_result/"
# Task identifier used to select the label directories.
task = "DS-ST"
# Frameworks with more arguments (nodes) than this are not loaded.
MAX_ARG = 200_000
# Allocator configuration taken from the environment (not read elsewhere in this file).
v = os.environ.get("PYTORCH_CUDA_ALLOC_CONF")

def transfom_to_graph(label_path, n):
    """Build a dense 0/1 target vector from a comma-separated label file.

    Every integer listed in the file is treated as a node index whose target
    becomes 1.0; all other positions stay 0.0.

    Args:
        label_path: Path to the label file (comma-separated node indices).
        n: Number of nodes, i.e. the length of the returned vector.

    Returns:
        A tensor of shape ``(n,)`` in the default float dtype, created on the
        module-level ``device`` with ``requires_grad=False``.
    """
    # ``with`` guarantees the handle is closed, even if parsing fails.
    with open(label_path, "r") as label_file:
        data = label_file.read()
    target = [0.0] * n
    # Use a separate name so the ``n`` parameter is not shadowed.
    for token in data.split(","):
        token = token.strip()
        # Skip empty tokens such as a trailing comma or a final newline.
        if not token:
            continue
        target[int(token)] = 1.0
    return torch.tensor(target, requires_grad=False, device=device)
def light_get_item(af_name, af_dir, label_dir, year):
    """Load one framework graph, reusing its cached ``.pt`` feature tensor.

    Args:
        af_name: File name of the framework (also names its label file).
        af_dir: Directory containing the framework file.
        label_dir: Directory containing the label file.
        year: Sub-folder of ``all_features`` holding the cached features.

    Returns:
        ``(graph, inputs, target, nb_el)`` where ``graph`` has self-loops
        added, or ``(None, None, None, 1000000)`` when the framework has more
        than ``MAX_ARG`` arguments.
    """
    af_path = os.path.join(af_dir, af_name)
    att1, att2, nb_el = af_reader_py.reading_file_for_dgl(af_path)
    if nb_el > MAX_ARG:
        # Sentinel count larger than MAX_ARG so callers drop this item.
        return None, None, None, 1000000
    target = transfom_to_graph(os.path.join(label_dir, af_name), nb_el)
    edge_lists = (torch.tensor(att1), torch.tensor(att2))
    base_graph = dgl.graph(edge_lists, num_nodes=nb_el, device=device)
    graph = dgl.add_self_loop(base_graph)
    feature_file = f"{af_data_root}all_features/{year}/{af_name}.pt"
    inputs = torch.load(feature_file, map_location=device)
    return graph, inputs, target, nb_el

def get_item(af_name, af_dir, label_dir, year):
    """Build one framework graph from scratch and cache its node features.

    The features returned by ``af_reader_py.compute_features`` are
    standardised with a ``StandardScaler`` fitted on this graph alone, stored
    as float16, and saved to ``<af_data_root>all_features/<year>/<af_name>.pt``.

    Args:
        af_name: File name of the framework (also names its label file).
        af_dir: Directory containing the framework file.
        label_dir: Directory containing the label file.
        year: Sub-folder of ``all_features`` where the features are cached.

    Returns:
        ``(graph, inputs, target, nb_el)`` where ``graph`` has self-loops
        added, or ``(None, None, None, 1000000)`` when the framework has more
        than ``MAX_ARG`` arguments.
    """
    print(af_name, " ", af_dir, " ", label_dir)
    af_path = os.path.join(af_dir, af_name)
    label_path = os.path.join(label_dir, af_name)
    att1, att2, nb_el = af_reader_py.reading_file_for_dgl(af_path)
    if nb_el > MAX_ARG:
        # Sentinel count larger than MAX_ARG so callers drop this item.
        return None, None, None, 1000000
    target = transfom_to_graph(label_path, nb_el)
    graph = dgl.graph((torch.tensor(att1), torch.tensor(att2)), num_nodes=nb_el, device=device)
    raw_features = af_reader_py.compute_features(af_path, 10000, 0.00001)
    scaler = StandardScaler()
    features = scaler.fit_transform(raw_features)
    # float16 halves host memory and disk usage; a float32 model needs these
    # cast back (e.g. ``.float()``) before the forward pass.
    inputs = torch.tensor(features, dtype=torch.float16, requires_grad=False).to(device)
    cache_path = af_data_root + "all_features/" + year + "/" + af_name + ".pt"
    # The per-year cache folder may not exist on a first run.
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    torch.save(inputs, cache_path)
    graph = dgl.add_self_loop(graph)
    return graph, inputs, target, nb_el

class TrainingGraphDataset(DGLDataset):
    """Training split: the labelled frameworks found in the 2017 directories.

    ``DGLDataset.__init__`` calls :meth:`process` before this constructor
    stores ``af_dir``/``label_dir``, and ``process`` builds its own paths from
    ``af_data_root``, ``result_root`` and ``task``. The constructor arguments
    therefore do not choose which files are loaded; they only overwrite the
    ``af_dir``/``label_dir`` attributes afterwards.
    """

    def __init__(self, af_dir, label_dir):
        super().__init__(name="Af dataset")
        self.label_dir = label_dir
        self.af_dir = af_dir

    def __len__(self):
        return len(self.graphs)

    @staticmethod
    def _base_name(file_name):
        """Strip known extensions so files differing only by them share one key."""
        # Same replacement order as before: .apx, .af, .tgf, .old.
        for suffix in (".apx", ".af", ".tgf", ".old"):
            file_name = file_name.replace(suffix, "")
        return file_name

    def process(self):
        """Load every labelled framework of the selected years into ``self.graphs``.

        Features cached under ``all_features/<year>/`` are reused via
        ``light_get_item``; otherwise ``get_item`` computes (and caches) them.
        Frameworks that the loaders reject (more than ``MAX_ARG`` arguments)
        are skipped, and only the first file per base name is loaded.
        """
        list_year_dir = ["2017"]
        self.af_dir = af_data_root + "dataset_af"
        self.label_dir = result_root + "result_" + task
        self.graphs = []
        self.labels = []
        seen_names = set()
        for year in list_year_dir:
            label_year_dir = self.label_dir + "_" + year
            af_year_dir = self.af_dir + "_" + year + "/"
            # Sorted so the file picked among duplicates is deterministic.
            for file_name in sorted(os.listdir(label_year_dir)):
                true_name = self._base_name(file_name)
                if true_name in seen_names:
                    continue
                seen_names.add(true_name)
                loader_args = (file_name, af_year_dir, label_year_dir + "/", year)
                cache_path = af_data_root + "all_features/" + year + "/" + file_name + ".pt"
                if os.path.exists(cache_path):
                    graph, features, label, _ = light_get_item(*loader_args)
                else:
                    graph, features, label, _ = get_item(*loader_args)
                if graph is None:
                    # Framework larger than MAX_ARG.
                    continue
                graph.ndata["feat"] = features
                graph.ndata["label"] = label
                self.graphs.append(graph)

    def __getitem__(self, idx: int):
        return self.graphs[idx]
class ValisationDataset(DGLDataset):
    """Validation split: the labelled frameworks found in the 2023 directories.

    ``DGLDataset.__init__`` calls :meth:`process` before this constructor
    stores ``af_dir``/``label_dir``, and ``process`` builds its own paths from
    ``af_data_root``, ``result_root`` and ``task``. The constructor arguments
    therefore do not choose which files are loaded; they only overwrite the
    ``af_dir``/``label_dir`` attributes afterwards.
    """

    def __init__(self, af_dir, label_dir):
        super().__init__(name="Validation Dataset")
        self.label_dir = label_dir
        self.af_dir = af_dir

    def __len__(self):
        return len(self.graphs)

    @staticmethod
    def _base_name(file_name):
        """Strip known extensions so files differing only by them share one key."""
        # Same replacement order as before: .apx, .af, .tgf, .old.
        for suffix in (".apx", ".af", ".tgf", ".old"):
            file_name = file_name.replace(suffix, "")
        return file_name

    def process(self):
        """Load every labelled framework of the selected years into ``self.graphs``.

        Features cached under ``all_features/<year>/`` are reused via
        ``light_get_item``; otherwise ``get_item`` computes (and caches) them.
        Frameworks that the loaders reject (more than ``MAX_ARG`` arguments)
        are skipped, and only the first file per base name is loaded.
        """
        list_year_dir = ["2023"]
        self.af_dir = af_data_root + "dataset_af"
        self.label_dir = result_root + "result_" + task
        self.graphs = []
        self.labels = []
        seen_names = set()
        for year in list_year_dir:
            label_year_dir = self.label_dir + "_" + year
            af_year_dir = self.af_dir + "_" + year + "/"
            # Sorted so the file picked among duplicates is deterministic.
            for file_name in sorted(os.listdir(label_year_dir)):
                true_name = self._base_name(file_name)
                if true_name in seen_names:
                    continue
                seen_names.add(true_name)
                loader_args = (file_name, af_year_dir, label_year_dir + "/", year)
                cache_path = af_data_root + "all_features/" + year + "/" + file_name + ".pt"
                if os.path.exists(cache_path):
                    graph, features, label, _ = light_get_item(*loader_args)
                else:
                    graph, features, label, _ = get_item(*loader_args)
                if graph is None:
                    # Framework larger than MAX_ARG.
                    continue
                graph.ndata["feat"] = features
                graph.ndata["label"] = label
                self.graphs.append(graph)

    def __getitem__(self, idx: int):
        return self.graphs[idx]
class GAT(nn.Module):
    """Three-layer GATv2 network producing ``out_size`` scores per node.

    Hidden layers concatenate their attention heads (``flatten(1)``); the last
    layer averages them (``mean(1)``), so ``forward`` returns a tensor of
    shape ``(num_nodes, out_size)``.

    Args:
        in_size: Number of input node features.
        hid_size: Per-head hidden size of layers 1 and 2.
        out_size: Per-node output size of the last layer.
        heads: Number of attention heads for each of the three layers.
    """

    def __init__(self, in_size, hid_size, out_size, heads):
        # nn.Module.__init__ must run before any submodule is assigned,
        # otherwise assigning ``self.gat_layers`` raises AttributeError.
        super().__init__()
        self.gat_layers = nn.ModuleList()
        # three-layer GAT
        # NOTE(review): only the first argument line of layers 2 and 3 was
        # visible; the remaining arguments follow DGL's GAT example layout.
        # Add ``residual=True`` if the original model used it.
        self.gat_layers.append(
            GATv2Conv(in_size, hid_size, heads[0], activation=F.elu)
        )
        self.gat_layers.append(
            GATv2Conv(
                hid_size * heads[0],
                hid_size,
                heads[1],
                activation=F.elu,
            )
        )
        # Output layer: no activation, it produces raw logits.
        self.gat_layers.append(
            GATv2Conv(
                hid_size * heads[1],
                out_size,
                heads[2],
                activation=None,
            )
        )

    def forward(self, g, inputs):
        h = inputs
        last = len(self.gat_layers) - 1
        for i, layer in enumerate(self.gat_layers):
            # GATv2Conv output: (num_nodes, num_heads, out_feats).
            h = layer(g, h)
            if i == last:
                h = h.mean(1)
            else:
                h = h.flatten(1)
        return h
device1 = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The dataset stays on the CPU; only the current batch is moved to ``device1``.
device = "cpu"
print("runtime : ", device)
print("runtime : ", device1)

# Peak GPU memory grows with the number of edges in a batch: GATv2Conv keeps
# per-edge tensors of size num_heads * out_feats for the backward pass.
# Lower BATCH_SIZE first when hitting OOM; raise ACCUM_STEPS to keep the same
# effective batch.
BATCH_SIZE = 8
# Gradients are accumulated over this many mini-batches before each optimizer
# step; set to 1 to step after every batch.
ACCUM_STEPS = 4

model = GAT(11, 5, 1, heads=[6, 4, 4]).to(device1)

model_path = "v3-"+task+"-11-gatv2.pth"

# The model outputs raw logits; BCEWithLogitsLoss fuses sigmoid + BCE in a
# numerically stable way.
loss = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
print("Loading Data...")
af_dataset = TrainingGraphDataset(af_data_root+"dataset_af/", af_data_root+"result/")
data_loader = dgl.dataloading.GraphDataLoader(af_dataset, batch_size=BATCH_SIZE, shuffle=True)

print("Start training")

for epoch in range(400):
    model.train()
    tot_loss = []
    tot_loss_v = 0
    optimizer.zero_grad(set_to_none=True)
    i = 0
    for i, graph in enumerate(data_loader, start=1):
        # One transfer: DGLGraph.to also moves ndata, so read features and
        # labels from the GPU graph instead of copying them a second time.
        graph_cdn = graph.to(device1)
        # Cached features may be float16 while the model weights are float32.
        inputs = graph_cdn.ndata["feat"].float()
        label = graph_cdn.ndata["label"]
        # squeeze(-1) keeps shape (N,) even for a batch with a single node.
        logits = model(graph_cdn, inputs).squeeze(-1)
        losse = loss(logits, label)
        # Scale so the accumulated gradient is a mean over ACCUM_STEPS batches.
        (losse / ACCUM_STEPS).backward()
        # Keep Python floats only: storing the loss tensor would keep its
        # autograd graph (and the GPU activations) alive.
        loss_value = losse.item()
        tot_loss.append(loss_value)
        tot_loss_v += loss_value
        if i % ACCUM_STEPS == 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
        del graph_cdn, inputs, label, logits, losse
    # Apply gradients from a final, incomplete accumulation window.
    if i % ACCUM_STEPS != 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    if epoch == 150:
        for g in optimizer.param_groups:
            g['lr'] = 0.001
    # statistics.fmean/median raise on an empty list (empty dataset).
    if tot_loss:
        print(i, "Epoch : ", epoch," Mean : " , statistics.fmean(tot_loss), " Median : ", statistics.median(tot_loss), "loss : ", tot_loss_v)
    # Checkpoint the weights at the end of every epoch.
    torch.save(model.state_dict(), model_path)

A GNN’s memory usage depends mostly on the size of the input graphs. Have you tried training on a smaller dataset, such as Cora?

Thank you for your reply.
No, I haven’t. Do you want me to try Cora to make sure it can be handled by my GPU?

It depends on your purpose. A smaller dataset is a good choice for learning about GNNs. If you must train on large graphs, you’ll have to do it on a CPU or on a GPU with more memory.

Yes, but I have to use this dataset.
What surprises me is that GATv1 works on my device with the same parameters, while GATv2 seems to use much more memory. From reading the GATv2 paper, it should only change the order in which some operations are applied, so I was wondering why it uses so much more memory when it is just a fix of v1.

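I tried training a small demo with GATv1 and GATv2, and GATv2 does consume more memory. This comes from how the layer has to be computed.

GATv1 scores an edge as $a^\top [W h_i \,\|\, W h_j]$, which splits into one scalar term per source node and one per destination node. As a result, DGL's GATConv only materializes a per-edge tensor of shape (E, heads).

The "order change" in GATv2 applies the LeakyReLU before the dot product, $a^\top \mathrm{LeakyReLU}(W [h_i \,\|\, h_j])$. This sum can no longer be split, so GATv2Conv materializes per-edge tensors of shape (E, heads, out_feats) and keeps them for the backward pass. That is roughly out_feats times more per-edge memory.

It looks like training GATv2 entirely on a 1070 with the dataset you need is not possible.
Maybe you could consider using CUDA unified virtual memory (UVM) by modifying PyTorch's allocator? This will cost some performance, if you can live with it.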

Yes, thank you.
On Windows, it seems that when GPU memory is full, the driver falls back to “Shared GPU memory”, which appears to be system RAM shared with the GPU.
Is that what you are talking about?
I’ve never managed to enable it.

I was just surprised that I still hit OOM even with the whole dataset in RAM, which leaves the full 8 GB of GPU memory for the forward and backward passes.
Should I open an issue about the large difference in memory usage between GAT and GATv2?