Resolving an out-of-memory issue

Hi, I have a GCN layer defined as below. I’m processing a large graph (~300k entities and ~700k edges) and run out of memory on the GPU. While checking the GPU usage at each line, I noticed that the propagate function allocates a large amount of memory which is not freed after returning to the main training loop. I also noticed that a tensor of dimension [#nodes, #edges] is allocated during the call to update_all. Does this tensor really need to be kept around after the messages have been propagated?
Any suggestions or tips would be appreciated.

import torch
from torch.nn import Parameter
from torch.nn.init import xavier_normal_
import dgl.function as fn


class GCNLayer(RGCNLayer):  # RGCNLayer is our base layer class, defined elsewhere
    def __init__(self, in_feat, out_feat, num_rels, num_bases, bias=True,
                 activation=None, self_loop=False, dropout=0.2, bert=False, bert_trainable=False):
        super(GCNLayer, self).__init__(in_feat, out_feat, bias,
                                       activation, self_loop=self_loop,
                                       dropout=dropout, bert=bert, bert_trainable=bert_trainable)
        self.num_rels = num_rels
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.weight = Parameter(torch.FloatTensor(self.in_feat, self.out_feat))
        self.weight_rel = torch.nn.Embedding(self.num_rels, 1, padding_idx=0)
        self.bn = torch.nn.BatchNorm1d(self.out_feat)
        xavier_normal_(self.weight.data)
        #stdv = 1. / np.sqrt(self.weight.size(1))
        #self.weight.data.uniform_(-stdv, stdv)

    def msg_func(self, edges):
        """
        Compute messages only from source node features
        """
        #alpha = self.weight_rel.index_select(0, edges.data['type'].squeeze())
        #edge_types = edges.data['type'].squeeze().to('cpu')
        edge_types = edges.data['type'].squeeze()
        alpha = self.weight_rel(edge_types)
        #edges.src['h'] = edges.src['h'].to("cpu")
        #self.weight = self.weight.cpu()
        node = torch.mm(edges.src['h'], self.weight)
        #node = node.cuda()
        msg = alpha.expand_as(node) * node
        #msg = torch.bmm(node, weight).view(-1, self.out_feat)
        return {'msg': msg}

    def propagate(self, g):
        g.update_all(self.msg_func, fn.sum(msg='msg', out='h'), self.apply_func)


    def apply_func(self, nodes):
        return {'h': nodes.data['h'] * nodes.data['norm']}

Hi, the reason you see a [#nodes, #edges] tensor allocated is that you are using a custom message function together with a built-in reduce function. In this case DGL falls back to degree bucketing to parallelize the computation, which is not the most efficient strategy in many cases.
In general, the message function should be as simple as possible. For example, it seems you don’t have to do

node = torch.mm(edges.src['h'], self.weight)

inside the message function. You could precompute it in advance, for example:

g.ndata['h_mm'] = torch.mm(g.ndata['h'], self.weight)

As for alpha, I think you should precompute it in advance as well:

g.edata['alpha'] = self.weight_rel(g.edata['type'])

What you are actually doing is then a src_mul_edge operation, and DGL already provides a built-in function for it; try the following code:

g.update_all(fn.src_mul_edge(src='h_mm', edge='alpha', out='msg'), 
             fn.sum(msg='msg', out='h'), self.apply_func)

This should be much faster and more GPU-memory friendly (by the way, in DGL 0.3 it could be even faster).

The key is to use built-in functions as much as possible, so that our scheduler can detect what can be optimized.
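Putting these pieces together, here is a minimal sketch of how the revised propagate could look (an illustration only, reusing the attribute names from your layer; 'h_mm' and 'alpha' are just the field names chosen above):

import torch
import dgl.function as fn

def propagate(self, g):
    # project the node features once, instead of once per edge inside the message function
    g.ndata['h_mm'] = torch.mm(g.ndata['h'], self.weight)
    # look up the per-edge relation weight in advance; shape [#edges, 1]
    g.edata['alpha'] = self.weight_rel(g.edata['type'].squeeze())
    # built-in message and reduce functions, so DGL can fuse the computation
    # (if your DGL version does not broadcast the trailing size-1 dim of 'alpha',
    #  expand it to [#edges, out_feat] before calling update_all)
    g.update_all(fn.src_mul_edge(src='h_mm', edge='alpha', out='msg'),
                 fn.sum(msg='msg', out='h'),
                 self.apply_func)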
I hope this helps.


Thanks a lot @zihao! I tried your suggestion and still run out of memory on a 16GB GPU. I noticed that the large [#nodes, #edges] tensor is no longer allocated with the built-in functions, but there is now a [#nodes, #nodes] tensor allocated (probably for the adjacency matrix) which occupies a lot of memory. Is there a reason why a sparse implementation of this matrix is not used?
Also, is this tensor required for autograd? If not, is it possible to free it after the forward pass?

Hi, I think the [#nodes, #nodes] tensor may come from the default backward pass of PyTorch’s spmm: for y = x * spmat, the default backward computes d_spmat by first computing x^T * dy and then keeping only the entries at the non-zero positions of spmat, which is inefficient in terms of both speed and memory.
Could you please tell me which version of DGL you are using? I think this problem has been resolved in DGL 0.2.
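For illustration, here is a rough sketch of that naive backward (not the actual PyTorch kernel, just the idea of why a dense [#nodes, #nodes] buffer shows up):

import torch

def naive_spmm_backward(x, dy, spmat):
    # y = x @ spmat, with x dense and spmat sparse (shape [#nodes, #nodes]).
    # The naive gradient w.r.t. spmat first forms the full dense product ...
    d_dense = x.t() @ dy                      # dense [#nodes, #nodes] intermediate
    # ... and only afterwards keeps the entries at spmat's non-zero positions.
    mask = (spmat.to_dense() != 0).float()
    return (d_dense * mask).to_sparse()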

I am using DGL version 0.2, and I’ve also tried installing from source.

I tried running the following code and I did not see DGL creating a [# nodes, # nodes] matrix:

import torch
import dgl
import networkx

G = networkx.path_graph(1000000)    # so that there's no way to create a node x node matrix
g = dgl.DGLGraph(G)

g.ndata['h'] = torch.randn(g.number_of_nodes(), 30).cuda()
W = torch.randn(30, 50).cuda().requires_grad_()
g.ndata['h_mm'] = g.ndata['h'] @ W

g.edata['a'] = torch.randint(0, 10, (g.number_of_edges(),))
E = torch.randn(10, 50).cuda().requires_grad_()
g.edata['a_r'] = E[g.edata['a']]

g.update_all(
        message_func=dgl.function.src_mul_edge('h_mm', 'a_r', 'm'),
        reduce_func=dgl.function.sum('m', 'm_sum'),
        apply_node_func=lambda nodes: {'h_mm': nodes.data['h_mm'] + nodes.data['m_sum']})

loss = g.ndata['h_mm'].sum()
loss.backward()
print(W.grad.sum().item(), E.grad.sum().item())

Does the snippet above match your scenario?

Hi,

Thanks, the snippet does match my scenario closely. However, I still haven’t been able to get around the issue, since my model has many more parameters than the ones used by DGL.
Also, I tried tracking the tensors living in GPU memory while running your snippet with this tool (https://github.com/Oldpan/Pytorch-Memory-Utils), and it seems like a large [#nodes, ~2*#nodes] tensor is still allocated. I can’t vouch for how accurate the toolkit is, though, since it also lists the memory occupied as 799999M, whereas I don’t run out of memory when running your snippet. I would appreciate any pointers to resolve this.

I placed the tracker just before and after the update_all() call.

GPU Memory Track | 20-May-19-18:22:12 | Total Used Memory:1417.7 Mb

+ | 1 * Size:(30, 50)             | Memory: 0.006 M | <class 'torch.Tensor'>
+ | 1 * Size:(1999998, 50)        | Memory: 399.99 M | <class 'torch.Tensor'>
+ | 1 * Size:(1000000, 50)        | Memory: 200.0 M | <class 'torch.Tensor'>
+ | 1 * Size:(10, 50)             | Memory: 0.002 M | <class 'torch.Tensor'>
+ | 1 * Size:(1000000, 30)        | Memory: 120.0 M | <class 'torch.Tensor'>

At __main__ <module>: line 22                        Total Used Memory:1417.7 Mb

+ | 2 * Size:(1000000, 50)        | Memory: 400.0 M | <class 'torch.Tensor'>
+ | 1 * Size:(1000000, 1999998)   | Memory: 799999 M | <class 'torch.Tensor'>
+ | 3 * Size:(1999998,)           | Memory: 23.999 M | <class 'torch.Tensor'>
- | 1 * Size:(1000000, 50)        | Memory: 200.0 M | <class 'torch.Tensor'> 

At __main__ <module>: line 30                        Total Used Memory:1417.7 Mb
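As a cross-check that does not rely on the tool, you could also print PyTorch’s own allocator counters around the update_all() call from the snippet above (these only cover memory allocated through PyTorch, but they are not fooled by sparse tensor shapes):

import torch

def report(tag):
    # memory the caching allocator has actually handed out, plus the peak so far
    print('{}: allocated={:.1f} MB, peak={:.1f} MB'.format(
        tag,
        torch.cuda.memory_allocated() / 2**20,
        torch.cuda.max_memory_allocated() / 2**20))

report('before update_all')
g.update_all(
        message_func=dgl.function.src_mul_edge('h_mm', 'a_r', 'm'),
        reduce_func=dgl.function.sum('m', 'm_sum'),
        apply_node_func=lambda nodes: {'h_mm': nodes.data['h_mm'] + nodes.data['m_sum']})
report('after update_all')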

Hi,

Could you print the shapes at this line of code?

msg = alpha.expand_as(node) * node

Try printing the shapes of node, alpha.expand_as(node), and msg.

The shapes of node, alpha.expand_as(node), and msg are all the same: [#edges, hidden_dim] (in my case, [60000, 200]).

It looks like the tool does not reflect the actual memory consumption of sparse tensors. You can see here that the script simply calculates the tensor size from x.size(); for a sparse tensor, however, it should count the actual number of non-zero values.

>>> import torch
>>> i = torch.LongTensor([[0, 1, 1], [2, 0, 2]])  # only three nnz values
>>> v = torch.FloatTensor([3, 4, 5])
>>> x = torch.sparse.FloatTensor(i, v, torch.Size([1000, 1000]))
>>> x.size()  # size() gives the dense shape, not the actual storage
torch.Size([1000, 1000])
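A rough estimate of the real footprint, assuming the usual COO layout (an indices tensor plus a values tensor), would look something like this:

>>> x._nnz()  # number of stored (non-zero) entries
3
>>> # 6 int64 indices (48 bytes) + 3 float32 values (12 bytes)
>>> x._indices().numel() * x._indices().element_size() + x._values().numel() * x._values().element_size()
60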