CUDA out of memory with RGCN in tutorial

I recently ran into a CUDA out-of-memory error within a single training epoch. In each iteration, I feed a batch of graphs into the model and predict the answer. However, GPU memory grows with every iteration, and after a few batches it raises ‘CUDA out of memory’. When I tried another method without RGCN, this did not happen. Below is my code, adapted from https://docs.dgl.ai/en/0.4.x/tutorials/models/1_gnn/4_rgcn.html .

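import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F

class RGCNLayer(nn.Module):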
    def __init__(self, feat_size, num_rels, out_size=512, activation=None, gated = True):
        
        super(RGCNLayer, self).__init__()
        self.feat_size = feat_size
        self.num_rels = num_rels
        self.activation = activation
        self.gated = gated

        self.weight = nn.Parameter(torch.Tensor(self.num_rels, self.feat_size, out_size))
        # init trainable parameters
        nn.init.xavier_uniform_(self.weight,gain=nn.init.calculate_gain('relu'))
        
        if self.gated:
            self.gate_weight = nn.Parameter(torch.Tensor(self.num_rels, self.feat_size, 1))
            nn.init.xavier_uniform_(self.gate_weight,gain=nn.init.calculate_gain('sigmoid'))
    
    # @torchsnooper.snoop()
    def forward(self, g):
        
        weight = self.weight
        gate_weight = self.gate_weight if self.gated else None  # only defined when gated=True
        
        def message_func(edges):
            w = weight[edges.data['rel_type']]
            print(edges.data['rel_type'].shape) #1216
            print(w.shape) # 1216,512,512
            msg = torch.bmm(edges.src['h'].unsqueeze(1), w).squeeze()
            print(msg.shape) #1216,512
            print(edges.data['norm'].unsqueeze(1).shape) # 1216,1
            msg = msg * edges.data['norm'].unsqueeze(1).cuda()
            
            if self.gated:
                gate_w = gate_weight[edges.data['rel_type']]
                gate = torch.bmm(edges.src['h'].unsqueeze(1), gate_w).squeeze().reshape(-1,1)
                gate = torch.sigmoid(gate)
                msg = msg * gate
                
            return {'msg': msg}
    
        def apply_func(nodes):
            h = nodes.data['h']
            h = self.activation(h)
            print('gg')
            return {'h': h}
        # gpu_tracker.track()
        g.update_all(message_func, fn.sum(msg='msg', out='h'), apply_func)
        # gpu_tracker.track()

class RGCNModel(nn.Module):
    def __init__(self, h_dim, num_rels, num_hidden_layers=1, gated = True):
        super(RGCNModel, self).__init__()

        self.h_dim = h_dim
        self.num_rels = num_rels
        self.num_hidden_layers = num_hidden_layers
        self.gated = gated
        
        # create rgcn layers
        self.build_model()
       
    def build_model(self):        
        self.layers = nn.ModuleList() 
        for _ in range(self.num_hidden_layers):
            rgcn_layer = RGCNLayer(self.h_dim, self.num_rels, activation=F.relu, gated = self.gated)
            self.layers.append(rgcn_layer)
            
    #@torchsnooper.snoop()
    def forward(self, g):
        for layer in self.layers:
            layer(g)
        # gpu_tracker.track()
        rst_hidden = []
        for sub_g in dgl.unbatch(g):
            print(sub_g.ndata['h'].shape)
            rst_hidden.append(  sub_g.ndata['h']   )
        # gpu_tracker.track()
        return rst_hidden

Hi,

Your code seems to work well for me without leakage. Here is how I used your code:

import numpy as np
import dgl
import dgl.function as fn
import torch.nn as nn
import torch
import torch.nn.functional as F

class RGCNLayer(nn.Module):
    def __init__(self, feat_size, num_rels, out_size=512, activation=None, gated = True):

        super(RGCNLayer, self).__init__()
        self.feat_size = feat_size
        self.num_rels = num_rels
        self.activation = activation
        self.gated = gated

        self.weight = nn.Parameter(torch.Tensor(self.num_rels, self.feat_size, out_size))
        # init trainable parameters
        nn.init.xavier_uniform_(self.weight,gain=nn.init.calculate_gain('relu'))

        if self.gated:
            self.gate_weight = nn.Parameter(torch.Tensor(self.num_rels, self.feat_size, 1))
            nn.init.xavier_uniform_(self.gate_weight,gain=nn.init.calculate_gain('sigmoid'))

    # @torchsnooper.snoop()
    def forward(self, g):

        weight = self.weight
        gate_weight = self.gate_weight if self.gated else None  # only defined when gated=True

        def message_func(edges):
            w = weight[edges.data['rel_type']]
            #print(edges.data['rel_type'].shape) #1216
            #print(w.shape) # 1216,512,512
            msg = torch.bmm(edges.src['h'].unsqueeze(1), w).squeeze()
            #print(msg.shape) #1216,512
            #print(edges.data['norm'].unsqueeze(1).shape) # 1216,1
            msg = msg * edges.data['norm'].unsqueeze(1).cuda()

            if self.gated:
                gate_w = gate_weight[edges.data['rel_type']]
                gate = torch.bmm(edges.src['h'].unsqueeze(1), gate_w).squeeze().reshape(-1,1)
                gate = torch.sigmoid(gate)
                msg = msg * gate

            return {'msg': msg}

        def apply_func(nodes):
            h = nodes.data['h']
            h = self.activation(h)
            #print('gg')
            return {'h': h}
        # gpu_tracker.track()
        g.update_all(message_func, fn.sum(msg='msg', out='h'), apply_func)
        # gpu_tracker.track()

class RGCNModel(nn.Module):
    def __init__(self, h_dim, num_rels, num_hidden_layers=1, gated = True):
        super(RGCNModel, self).__init__()

        self.h_dim = h_dim
        self.num_rels = num_rels
        self.num_hidden_layers = num_hidden_layers
        self.gated = gated

        # create rgcn layers
        self.build_model()

    def build_model(self):
        self.layers = nn.ModuleList()
        for _ in range(self.num_hidden_layers):
            rgcn_layer = RGCNLayer(self.h_dim, self.num_rels, activation=F.relu, gated = self.gated)
            self.layers.append(rgcn_layer)

    #@torchsnooper.snoop()
    def forward(self, g):
        for layer in self.layers:
            layer(g)
        # gpu_tracker.track()
        rst_hidden = []
        for sub_g in dgl.unbatch(g):
            #print(sub_g.ndata['h'].shape)
            rst_hidden.append(  sub_g.ndata['h']   )
        # gpu_tracker.track()
        return rst_hidden

graphs = []
for _ in range(10 * 50):
    g = dgl.DGLGraph((np.random.randint(0, 100, (200,)), np.random.randint(0, 100, (200,))))
    g.ndata['h'] = torch.randn(g.number_of_nodes(), 512).cuda()
    g.edata['rel_type'] = torch.randint(0, 10, (200,)).cuda()
    g.edata['norm'] = torch.randn(200).cuda()
    graphs.append(g)

for i in range(50):
    bg = dgl.batch(graphs[i*10:(i+1)*10])
    rgcn = RGCNModel(512, 10, 2).cuda()
    opt = torch.optim.Adam(rgcn.parameters())
    h = rgcn(bg)
    loss = torch.cat(h, 0).sum()
    opt.zero_grad()
    loss.backward()
    opt.step()
    print('{:.1f} MiB'.format(torch.cuda.max_memory_allocated() / 1000000))

Could you test-run it? If it also works for you, then the problem may not be in the RGCN model itself.


Thanks for your quick reply. I tried your code and it works. I noticed that you assumed every graph has 200 edges. When I randomize the number of edges in [0,1000], memory keeps increasing. My modified code and the print log are below.

graphs = []
for _ in range(10 * 50):
    num_e = np.random.randint(200, 500)
    a = np.random.randint(0, 100, (num_e,))
    b = np.random.randint(0, 100, (num_e,))
    
    # Resample until node 0 appears in a or b
    while not (0 in a or 0 in b):
        a = np.random.randint(0, 100, (num_e,))
        b = np.random.randint(0, 100, (num_e,))
    g = dgl.DGLGraph((a, b))
    print(g)
    print(_)
    g.ndata['h'] = torch.randn(g.number_of_nodes(), 512).cuda()
    g.edata['rel_type'] = torch.randint(0, 10, (num_e,)).cuda()
    g.edata['norm'] = torch.randn(num_e).cuda()
    graphs.append(g)
print('{:.1f} MiB'.format(torch.cuda.max_memory_allocated() / 1000000))

for i in range(500):
    bg = dgl.batch(graphs[i*1:(i+1)*1])
    rgcn = RGCNModel(512, 10, 2).cuda()
    opt = torch.optim.Adam(rgcn.parameters())
    h = rgcn(bg)
    loss = torch.cat(h, 0).sum()
    #print(loss.item())
    opt.zero_grad()
    loss.backward()

    opt.step()

    print('{:.1f} MiB'.format(torch.cuda.max_memory_allocated() / 1000000))

Print log:

1964.0 MiB
1964.0 MiB
1964.0 MiB
2229.4 MiB
2229.4 MiB
2229.4 MiB
2229.4 MiB
2229.4 MiB
2235.0 MiB
2461.5 MiB
2461.5 MiB
2461.5 MiB
2461.5 MiB
2571.7 MiB
2571.7 MiB
2571.7 MiB
2571.7 MiB
2571.7 MiB
2949.1 MiB
2949.1 MiB
2949.1 MiB
2949.1 MiB
2949.1 MiB
2949.1 MiB
3100.0 MiB
3100.0 MiB
3165.5 MiB
3165.5 MiB
3165.5 MiB
3165.5 MiB
3165.5 MiB
3165.5 MiB
3165.5 MiB
3220.6 MiB
3220.6 MiB
3363.5 MiB
3363.5 MiB
3363.5 MiB
3363.5 MiB
3363.5 MiB
3363.5 MiB
3363.5 MiB
3363.5 MiB
3363.5 MiB
3972.5 MiB
3972.5 MiB
3972.5 MiB
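A small aside on units: the scripts in this thread divide torch.cuda.max_memory_allocated() by 1,000,000, so the printed numbers are megabytes (10**6 bytes), not mebibytes (2**20 bytes) as the "MiB" label suggests. The conversion below is plain arithmetic (the helper name is hypothetical); dividing by 2**20 in the print statement would give true MiB directly.

```python
# The logs above print bytes / 1_000_000, i.e. MB, although they are labelled "MiB".
MB = 10**6
MIB = 2**20

def printed_value_to_mib(printed: float) -> float:
    """Convert a value printed as bytes / 1_000_000 into true MiB (bytes / 2**20)."""
    return printed * MB / MIB

# Last value from the log above.
last = printed_value_to_mib(3972.5)
print(f"{last:.1f} MiB")  # 3788.5 MiB
```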

I noticed that with batch_size = 1, i.e. bg = dgl.batch(graphs[i*1:(i+1)*1]), the memory keeps increasing with every iteration.

In my case it increases for a while and then stops at a maximum. So I would say the increase is not due to a leak but simply to the varying input size, since batching different graphs will likely yield a different number of nodes and edges.
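One detail supports this reading: torch.cuda.max_memory_allocated() reports the peak since the program started (or since torch.cuda.reset_peak_memory_stats()), so the printed number can never go down; it only rises when a batch is larger than every earlier one. torch.cuda.memory_allocated() gives the current value instead. The sketch below mimics that counter with a running maximum over a hypothetical workload whose memory is assumed proportional to the edge count (it ignores everything except the gathered 512x512 weights), with edge counts drawn like the repro above.

```python
import random

def running_peak(batch_bytes):
    """Mimic a peak-memory counter: it only rises when a batch exceeds every earlier one."""
    peak = 0
    history = []
    for b in batch_bytes:
        peak = max(peak, b)
        history.append(peak)
    return history

# Hypothetical workload: batch size 1, 200 <= num_e < 500 as in the repro above.
rng = random.Random(0)
edges = [rng.randrange(200, 500) for _ in range(500)]
bytes_per_edge = 512 * 512 * 4  # one gathered 512x512 float32 weight matrix per edge
peaks = running_peak([e * bytes_per_edge for e in edges])

# Count how often the reported peak actually increases.
jumps = sum(1 for prev, cur in zip(peaks, peaks[1:]) if cur > prev)
print(f"peak rose {jumps} times over {len(peaks)} iterations")
```

The steps-then-plateau shape of the log above is what such a counter produces even when no memory leaks: each step marks a new largest batch.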

Perhaps your “out of memory” error is caused by the large hidden dimensionality. Could you reduce it (e.g. to 256)?
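A rough estimate of why the hidden size matters so much: in the question's message_func, weight[edges.data['rel_type']] builds one in_dim x out_dim matrix per edge (the printed w.shape is 1216, 512, 512), so that tensor alone costs E * d * d * 4 bytes in float32. A minimal sketch of the arithmetic, with a hypothetical helper name:

```python
MIB = 2**20

def gathered_weight_bytes(num_edges: int, in_dim: int, out_dim: int, bytes_per_elem: int = 4) -> int:
    """Size of weight[edges.data['rel_type']]: one (in_dim x out_dim) matrix per edge."""
    return num_edges * in_dim * out_dim * bytes_per_elem

# 1216 edges, matching the printed w.shape (1216, 512, 512) in the question.
for dim in (512, 256):
    size = gathered_weight_bytes(1216, dim, dim)
    print(f"dim={dim}: {size / MIB:.0f} MiB")
# dim=512: 1216 MiB
# dim=256: 304 MiB
```

At d = 512 each edge costs exactly 1 MiB, and halving d cuts this by a factor of 4. Since torch.bmm keeps its inputs for the backward pass and the gradient of the gathered tensor has the same shape, the real per-layer cost is likely at least twice this figure.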

Is there any way to remove the graphs from the GPU? Reducing the dimension is not a good option for me, because I have 70k graphs with an average of 100 nodes and 1k edges, and their sizes vary a lot. GPU memory blows up after a few iterations: on my 12GB RTX 2080Ti, it raises “CUDA out of memory” after about 10 iterations with batch size 4.

I see. So you put all the graphs on the GPU before training? In that case, you can keep all the graphs on the CPU and copy each batch to the GPU on demand during training, like this:

graphs = []
import tqdm
for _ in tqdm.trange(10 * 50):
    num_e = np.random.randint(200, 500)
    a = np.random.randint(1, 100, (num_e,))
    b = np.random.randint(1, 100, (num_e,))
    g = dgl.DGLGraph((a, b))
    g.ndata['h'] = torch.randn(g.number_of_nodes(), 256)  # here
    g.edata['rel_type'] = torch.randint(0, 10, (num_e,))  # here
    g.edata['norm'] = torch.randn(num_e)  # here
    graphs.append(g)

rgcn = RGCNModel(256, 10, 2).cuda()
opt = torch.optim.Adam(rgcn.parameters())  # create once, outside the loop, so Adam's state persists
for _ in range(10):
    for i in range(50):
        bg = dgl.batch(graphs[i*10:(i+1)*10]).to(torch.device('cuda'))  # here
        h = rgcn(bg)
        loss = torch.cat(h, 0).sum()
        opt.zero_grad()
        loss.backward()
        opt.step()
        print('{:.1f} MiB'.format(torch.cuda.max_memory_allocated() / 1000000))
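As a rough check on why the whole dataset cannot live on the GPU, here is the arithmetic for the features alone, using the figures from the thread (about 70k graphs, 100 nodes and 1k edges each). The 512-dim float32 node features, int64 rel_type (the torch.randint default) and float32 norm are assumptions matching the code above; graph structure is ignored, so this is a lower bound. The helper name is hypothetical.

```python
GIB = 2**30

def dataset_feature_bytes(num_graphs, avg_nodes, avg_edges, feat_dim,
                          feat_bytes=4, rel_type_bytes=8, norm_bytes=4):
    """Lower bound on memory for node/edge features alone (graph structure ignored)."""
    node_feats = num_graphs * avg_nodes * feat_dim * feat_bytes  # ndata['h'], float32
    rel_types = num_graphs * avg_edges * rel_type_bytes          # edata['rel_type'], int64
    norms = num_graphs * avg_edges * norm_bytes                  # edata['norm'], float32
    return node_feats + rel_types + norms

# ~70k graphs, ~100 nodes and ~1k edges each; feature size 512 is an assumption.
total = dataset_feature_bytes(70_000, 100, 1_000, 512)
print(f"{total / GIB:.1f} GiB")  # 14.1 GiB
```

That already exceeds a 12GB card before any activations are allocated, which is why copying one batch at a time, as in the loop above, is the usual pattern.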

I added a clear command at the end of each iteration, like this:

rgcn = RGCNModel(256, 10, 2).cuda()
opt = torch.optim.Adam(rgcn.parameters())  # create once, outside the loop, so Adam's state persists
for _ in range(10):
    for i in range(50):
        bg = dgl.batch(graphs[i*10:(i+1)*10]).to(torch.device('cuda'))  # here
        h = rgcn(bg)
        loss = torch.cat(h, 0).sum()
        opt.zero_grad()
        loss.backward()
        opt.step()

        bg.clear() #here

        print('{:.1f} MiB'.format(torch.cuda.max_memory_allocated() / 1000000))

With this, memory no longer blows up. But on my own dataset, bg.clear() raises the following error:

dgl._ffi.base.DGLError: [18:43:55] /opt/dgl/include/dgl/immutable_graph.h:576: Clear isn't supported in ImmutableGraph

By the way, I load my dataset in this way:

from dgl.data.utils import load_graphs
graphs = load_graphs("../data/graph.bin")[0]

Does this mean the loaded graphs are ImmutableGraphs? If I want to clear a graph, do I have to recreate it in the program?

You don’t have to call bg.clear() since bg will be recycled once another batch comes in. Does your code not work if you don’t have this statement? If so, what is your DGL version?
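The "recycled once another batch comes in" point is ordinary Python reference counting: rebinding bg to the next batch drops the last reference to the previous one, so CPython frees it immediately, as long as nothing else (for example h or loss from the previous step) still refers to it. The toy class below is a hypothetical stand-in for a batched graph; weakref lets us observe which objects are still alive. Note that PyTorch's caching allocator keeps freed GPU blocks reserved for reuse, so nvidia-smi may not show a drop even when memory_allocated() does.

```python
import weakref

class Batch:
    """Hypothetical stand-in for a batched graph holding large tensors."""
    def __init__(self, i):
        self.i = i

def train_step(batch):
    return batch.i  # placeholder for forward/backward; keeps no reference to batch

refs = []
bg = None
for i in range(3):
    bg = Batch(i)  # rebinding drops the last reference to the previous batch
    refs.append(weakref.ref(bg))
    train_step(bg)

# In CPython an object is freed as soon as its reference count reaches zero.
alive = [r() is not None for r in refs]
print(alive)  # [False, False, True]
```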

Yes. If I delete bg.clear(), your example still runs out of memory. My DGL version is '0.4.3post2'.

Very strange, as I cannot reproduce the problem. Setting the hidden dimensionality to 512 and the batch size to 10 exhausts my 8GB GPU right at the start, but that is because the computation itself takes a lot of memory. A hidden dimensionality of 512 with batch size 4 runs fine, peaking at 3874.8 MiB.

Could you try the nightly version?

pip install --pre --upgrade dgl

I have now updated to v0.5.0. The growth of GPU memory does have an upper limit, but whenever a larger graph comes in, memory rises to a new upper limit. When I process a graph with 1190 nodes and 13476 edges, it raises “CUDA out of memory” again. It looks like I need to give up on these very large graphs.

Anyway, thanks for your reply and wish you a nice weekend.

How many relation types do you have?
The RelGraphConv module in dgl.nn has a low_mem option that can save memory.
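To see why looping over relation types can help, here is a pure-Python sketch (stdlib only, toy sizes) comparing two ways to compute the messages from the question: gathering one weight matrix per edge, as weight[edges.data['rel_type']] does, versus looping over relation types and applying each shared matrix to that relation's edges. As far as I can tell, RelGraphConv's low_mem option follows the second pattern, but treat that as an assumption and check the DGL source for your version. All function names are hypothetical, and the byte estimate counts only the gathered weights (first variant) or the E x dim message buffer (second variant), ignoring autograd bookkeeping.

```python
import random

def matvec(x, W):
    """Row vector x (length in_dim) times matrix W (in_dim x out_dim)."""
    return [sum(xk * W[k][j] for k, xk in enumerate(x)) for j in range(len(W[0]))]

def messages_gather(src_h, rel_type, weights):
    # Mirrors weight[edges.data['rel_type']]: conceptually one weight matrix per edge.
    # (Python lists only hold references; a torch index would copy E matrices.)
    per_edge_w = [weights[r] for r in rel_type]
    return [matvec(h, W) for h, W in zip(src_h, per_edge_w)]

def messages_per_relation(src_h, rel_type, weights):
    # Loop over relations: each relation's edges use the single shared matrix,
    # so no per-edge copies of W are needed.
    msg = [None] * len(src_h)
    for r, W in enumerate(weights):
        for e, t in enumerate(rel_type):
            if t == r:
                msg[e] = matvec(src_h[e], W)
    return msg

def extra_bytes(num_edges, dim, per_relation, bytes_per_elem=4):
    # Gather: E x dim x dim weights; per relation: only the E x dim messages.
    return num_edges * dim * (1 if per_relation else dim) * bytes_per_elem

rng = random.Random(0)
num_rels, dim, num_edges = 3, 4, 10
weights = [[[rng.uniform(-1, 1) for _ in range(dim)] for _ in range(dim)] for _ in range(num_rels)]
src_h = [[rng.uniform(-1, 1) for _ in range(dim)] for _ in range(num_edges)]
rel_type = [rng.randrange(num_rels) for _ in range(num_edges)]

# Both strategies compute exactly the same messages.
assert messages_gather(src_h, rel_type, weights) == messages_per_relation(src_h, rel_type, weights)

MIB = 2**20
# The large graph from the thread: 13476 edges, hidden size 512.
print(f"gather: {extra_bytes(13476, 512, False) / MIB:.0f} MiB")      # 13476 MiB
print(f"per relation: {extra_bytes(13476, 512, True) / MIB:.1f} MiB")  # 26.3 MiB
```

In PyTorch the inner loop would be one matmul over the rows selected for relation r, so with 58 relations this means 58 smaller matmuls per layer, trading some speed for a large memory saving.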


Thanks for your attention. I have 58 relations. I will look into the low_mem implementation.