High GPU memory usage problem

(T Qri) #1

Hi, I implemented an attention-based Sequence-to-sequence model in Theano and then ported it into PyTorch.

However, the GPU memory usage in Theano is only around 2GB, while PyTorch requires almost 5GB, although it’s much faster than Theano.

Maybe it’s a trade-off between memory and speed, but a 2.5x increase in GPU memory usage is unacceptable. I think there should be room to reduce GPU memory usage while maintaining high efficiency.

I printed out the allocated, max_allocated, cached and max_cached GPU memory after zero_grad, model forwarding, loss calculation, backward and optimizer step.

I don’t quite understand how memory usage changes according to the following output.

1). In the 0-th step, max allocated memory rapidly increased from 2154M to 2566M after “compute loss”. Is this normal?
2). In the 0-th step, after “backward”, max allocated memory increased again, from 2566M to 3482M. I assumed the extra 916M consists of Adam's exp_avg and exp_avg_sq for each parameter, but then the expected increase would be about 2*281M=562M, not 916M. According to my understanding, all the intermediate results required to calculate derivatives are already stored in the forward pass, so there should be no more memory allocation during backward.
3). In the 0-th step, after “backward”, allocated memory is 1020M, while after the optimizer step, the allocated memory is 1581M. What’s taking up the extra 1300M (1581M - 281M) here? Although I invoke model.zero_grad() in the 1-st step, the allocated memory is still 1581M; it seems that PyTorch didn’t release any GPU memory.
4). In the 1-st step, after “backward”, why did max allocated memory increase by another 847M, from 3482M to 4329M, as in 2)? Note that during training, every batch is exactly the same.
5). Finally, is there anything in my code that can be optimized to reduce the GPU memory usage?

Thanks in advance. I'd really appreciate it if anyone could explain these problems to me.

Below is the training output and my code.

-- beginning
allocated: 281M, max allocated: 281M, cached: 284M, max cached: 284M
-- zero grads
allocated: 281M, max allocated: 281M, cached: 284M, max cached: 284M
-- forward
allocated: 2108M, max allocated: 2154M, cached: 2178M, max cached: 2178M
-- compute loss
allocated: 2566M, max allocated: 2566M, cached: 2635M, max cached: 2635M
-- backward
allocated: 1020M, max allocated: 3482M, cached: 3551M, max cached: 3551M
-- optimizer step
allocated: 1581M, max allocated: 3482M, cached: 3551M, max cached: 3551M
***************** finish train step: 0

-- zero grads
allocated: 1581M, max allocated: 3482M, cached: 3551M, max cached: 3551M
-- forward
allocated: 2955M, max allocated: 3482M, cached: 3552M, max cached: 3552M
-- compute loss
allocated: 3413M, max allocated: 3482M, cached: 3552M, max cached: 3552M
-- backward
allocated: 1581M, max allocated: 4329M, cached: 4468M, max cached: 4468M
-- optimizer step
allocated: 1581M, max allocated: 4329M, cached: 4468M, max cached: 4468M
***************** finish train step: 1

-- zero grads
allocated: 1581M, max allocated: 4329M, cached: 4468M, max cached: 4468M
-- forward
allocated: 2955M, max allocated: 4329M, cached: 4468M, max cached: 4468M
-- compute loss
allocated: 3413M, max allocated: 4329M, cached: 4468M, max cached: 4468M
-- backward
allocated: 1581M, max allocated: 4329M, cached: 4468M, max cached: 4468M
-- optimizer step
allocated: 1581M, max allocated: 4329M, cached: 4468M, max cached: 4468M
***************** finish train step: 2
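
Since "max allocated" is a running maximum over the whole process, the per-phase peaks in a log like the one above are easier to read if the counter is reset before each phase. Below is a minimal sketch of that idea. The helper name `phase_peaks` and the toy model are hypothetical and only for illustration; the sketch assumes a PyTorch version that provides `torch.cuda.reset_peak_memory_stats()`, and it simply returns None on a machine without CUDA.

```python
import torch
import torch.nn as nn

MIB = 1024 * 1024


def phase_peaks(model, x, y, criterion):
    """Return {phase name: peak allocated MiB} for one step, or None without CUDA."""
    if not torch.cuda.is_available():
        return None
    peaks = {}

    def run(name, fn):
        # Reset the running maximum so the reading covers this phase only.
        torch.cuda.reset_peak_memory_stats()
        out = fn()
        torch.cuda.synchronize()
        peaks[name] = torch.cuda.max_memory_allocated() / MIB
        return out

    logit = run('forward', lambda: model(x))
    loss = run('compute loss', lambda: criterion(logit, y))
    run('backward', loss.backward)
    return peaks


# Hypothetical toy model and batch, standing in for the real Seq2Seq model.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
toy_model = nn.Linear(32, 10).to(device)
toy_x = torch.randn(16, 32, device=device)
toy_y = torch.randint(0, 10, (16,), device=device)

peaks = phase_peaks(toy_model, toy_x, toy_y, nn.CrossEntropyLoss())
print(peaks)
```

With this, each reported peak belongs to a single phase instead of carrying over the highest value seen so far.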

Please refer to the following code as a ready-to-run example of my model and training loop.

import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence as unpack, pack_padded_sequence as pack
from torch.optim.adam import Adam

class Encoder(nn.Module):
    def __init__(self, args):
        super(Encoder, self).__init__()
        self.args = args

        self.rnn = nn.GRU(args.input_size, args.hidden_size, bidirectional=True, batch_first=True)
        self.embedding = nn.Embedding(args.vocab_size, args.input_size, padding_idx=args.pad_id)

    def forward(self, input, h0=None):
        """
        Args:
            input: long tensor of shape (B x T)
        """
        mask = input.eq(self.args.pad_id)
        total_length = input.size(1)
        # Number of non-pad tokens per sequence
        lens = input.size(1) - mask.sum(1)
        # Trim mask when trained with data-parallel
        if torch.cuda.device_count() > 1:
            max_len = lens.max().item()
            mask = mask[:, :max_len]
        x = self.embedding(input)
        x = pack(x, lens.tolist(), True)

        x, hx = self.rnn(x, h0)
        x, _ = unpack(x, batch_first=True, total_length=total_length)

        summary = hx[1]  # The last state of backward rnn
        return {
            'feature': x,
            'hidden': summary,
            'mask': mask
        }

class Bridge(nn.Module):
    """
    Bridge the connection between encoder and decoder.
    """

    def __init__(self, args):
        super(Bridge, self).__init__()
        # NOTE: whatever followed each nn.Linear inside nn.Sequential was cut off
        # in the post; only the surviving nn.Linear layers are kept here.
        self.map_k = nn.Sequential(nn.Linear(args.hidden_size * 2, args.attention_size))
        self.dec_init = nn.Sequential(nn.Linear(args.hidden_size, args.hidden_size))

    def forward(self, encoder_state):
        K = self.map_k(encoder_state['feature'])
        h0 = self.dec_init(encoder_state['hidden'])
        return {
            'key': K,
            'value': encoder_state['feature'],
            'mask': encoder_state['mask'],
            'hidden': h0,
        }

class Attention(nn.Module):
    def __init__(self, query_size, attention_size):
        super(Attention, self).__init__()
        self.query_size = query_size
        self.key_size = attention_size
        self.map_query = nn.Linear(query_size, attention_size)
        self.v = nn.Linear(attention_size, 1)

    def forward(self, query, keys, values, mask):
        """
        score = v^T * (tanh(W * K + U * q))
            query: B x D
            keys: B x T x D
            values: B x T x D
            mask: B x T
        """
        # B x T x D
        x = keys + self.map_query(query).unsqueeze(1)
        # B x T
        x = self.v(torch.tanh(x)).squeeze(-1)
        x.data.masked_fill_(mask, -float('inf'))
        x = F.softmax(x, -1)
        output = torch.bmm(x.unsqueeze(1), values).squeeze(1)
        return output, x

class Decoder(nn.Module):
    def __init__(self, args):
        super(Decoder, self).__init__()

        input_size = args.input_size
        hidden_size = args.hidden_size
        ctx_size = 2 * hidden_size

        self.embedding = nn.Embedding(args.vocab_size, input_size, args.pad_id)

        self.cell0 = nn.GRUCell(input_size, hidden_size)
        self.attention = Attention(hidden_size, args.attention_size)
        self.cell1 = nn.GRUCell(ctx_size, hidden_size)

        readout = nn.Sequential(
            nn.Linear(input_size + hidden_size + ctx_size, input_size),
            # NOTE: any layers after this nn.Linear were cut off in the post.
        )

        logit = nn.Linear(input_size, args.vocab_size)

        # Assumed composition (the arguments were cut off in the post):
        # readout followed by the output projection.
        self.generator = nn.Sequential(readout, logit)

        if args.weight_tying:
            # Share the output projection with the target embedding.
            logit.weight = self.embedding.weight

    def forward(self, inputs, state):
        x = self.embedding(inputs)
        K, V, mask, h = state['key'], state['value'], state['mask'], state['hidden']
        hs = []
        ctxs = []
        for x_i in x.split(1, 1):
            x_i = x_i.squeeze(1)  # B x D
            h = self.cell0(x_i, h)
            ctx, _ = self.attention(h, K, V, mask)
            h = self.cell1(ctx, h)
            # Collect per-step states so they can be stacked below
            hs.append(h)
            ctxs.append(ctx)
        hs = torch.stack(hs, 1)
        ctxs = torch.stack(ctxs, 1)
        logit = self.generator(torch.cat([x, hs, ctxs], -1))
        state['hidden'] = h
        return logit, state

class Seq2Seq(nn.Module):
    def __init__(self, args):
        super(Seq2Seq, self).__init__()
        self.args = args
        self.encoder = Encoder(args)
        self.decoder = Decoder(args)
        self.bridge = Bridge(args)

    def encode(self, input):
        state = self.encoder(input)
        return state

    def decode(self, input, state):
        logit, state = self.decoder(input, state)
        return logit, state

    def forward(self, input_x, input_y):
        enc_state = self.encode(input_x)
        dec_state = self.bridge(enc_state)
        del enc_state
        # Reuse eos as bos to left-pad the target.
        input_y = F.pad(input_y, [1, 0], mode='constant', value=self.args.eos_id)
        # Exclude eos from input_y
        input_y = input_y[:, :-1]
        logit, dec_state = self.decode(input_y, dec_state)
        return logit

def stat_cuda(msg):
    # memory_cached / max_memory_cached are the older names of
    # memory_reserved / max_memory_reserved in newer PyTorch versions.
    print('--', msg)
    print('allocated: %dM, max allocated: %dM, cached: %dM, max cached: %dM' % (
        torch.cuda.memory_allocated() / 1024 / 1024,
        torch.cuda.max_memory_allocated() / 1024 / 1024,
        torch.cuda.memory_cached() / 1024 / 1024,
        torch.cuda.max_memory_cached() / 1024 / 1024
    ))

def fake_batch(n_batch, batch_size, max_length=50):
    for _ in range(n_batch):
        x = torch.ones(batch_size, max_length).long().cuda() * 2
        y = torch.ones(batch_size, max_length).long().cuda() * 2
        yield (x, y)

def main(args):
    model = Seq2Seq(args).cuda()
    criterion = nn.CrossEntropyLoss(ignore_index=args.pad_id, reduction='sum').cuda()
    optimizer = Adam(model.parameters())
    for p in model.parameters():
        nn.init.uniform_(p, -0.1, 0.1)
    stat_cuda('beginning')

    for i, batch in enumerate(fake_batch(100, args.batch_size)):
        x, y = batch
        batch_size, length = x.size()
        # 1. zero grads
        model.zero_grad()
        stat_cuda('zero grads')
        # 2. forward
        logit = model(x, y)
        stat_cuda('forward')
        # 3. compute loss
        logit = logit.view(-1, logit.size(-1))
        loss = criterion(logit, y.view(-1))
        loss = loss / batch_size
        stat_cuda('compute loss')
        # 4. backward
        loss.backward()
        stat_cuda('backward')
        # 5. optimizer step
        optimizer.step()
        stat_cuda('optimizer step')
        print('***************** finish train step: {}\n'.format(i))

if __name__ == '__main__':
    args = argparse.Namespace()
    args.attention_size = 1000
    args.hidden_size = 1000
    args.input_size = 500
    args.pad_id = 0
    args.vocab_size = 30000
    args.weight_tying = False
    args.batch_size = 80
    args.eos_id = 1

    main(args)
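
As a rough cross-check of the log, the parameter memory can be estimated by hand from the hyperparameters above. The sketch below is plain arithmetic, not a measurement. It assumes float32 parameters and that the layers cut off in the listing (inside the `nn.Sequential` calls) carry no parameters; the helper names `linear` and `gru` are hypothetical. The comparison of the logits tensor with the logged numbers is only a reading of the output, not something confirmed in this thread.

```python
MIB = 1024 * 1024
BYTES_PER_FLOAT = 4  # assumption: float32 everywhere

# Hyperparameters copied from the listing above
vocab, emb, hid, att = 30000, 500, 1000, 1000
batch, length = 80, 50


def linear(n_in, n_out):
    """Parameter count of nn.Linear(n_in, n_out) with bias."""
    return n_in * n_out + n_out


def gru(n_in, n_hid):
    """Parameter count of one GRU direction (or one GRUCell): 3 gates, 2 bias vectors."""
    return 3 * n_hid * (n_in + n_hid) + 2 * 3 * n_hid


n_params = (
    vocab * emb                          # Encoder.embedding
    + 2 * gru(emb, hid)                  # Encoder.rnn (bidirectional)
    + linear(2 * hid, att)               # Bridge.map_k
    + linear(hid, hid)                   # Bridge.dec_init
    + vocab * emb                        # Decoder.embedding
    + gru(emb, hid)                      # Decoder.cell0
    + linear(hid, att)                   # Attention.map_query
    + linear(att, 1)                     # Attention.v
    + gru(2 * hid, hid)                  # Decoder.cell1
    + linear(emb + hid + 2 * hid, emb)   # readout
    + linear(emb, vocab)                 # logit
)

param_mib = n_params * BYTES_PER_FLOAT / MIB
grad_mib = param_mib                     # one gradient per parameter
adam_mib = 2 * param_mib                 # exp_avg and exp_avg_sq
logit_mib = batch * length * vocab * BYTES_PER_FLOAT / MIB

print('parameters: %d (%.1f MiB)' % (n_params, param_mib))
print('gradients:  %.1f MiB' % grad_mib)
print('Adam state: %.1f MiB' % adam_mib)
print('logits (B x T x V): %.1f MiB' % logit_mib)
```

The parameter estimate lands close to the 281M reported at "beginning", and twice that is close to the 561M jump between "backward" and "optimizer step" in step 0. The logits tensor is about the size of the jump at "compute loss", and parameters, gradients, Adam state and logits together come close to the steady 1581M, which would fit the `logit` variable staying alive across iterations, though that attribution is an assumption.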

(Alban D) #2


Thanks for the detailed question and measurements!

1- When you compute the loss, you allocate memory for the new output Tensors and the intermediate results within the loss function itself, so an increase in memory usage during that step is expected. In particular, 400M is not too large.
2- The backward does not correspond to the Adam step. The backward goes through the computational graph and computes gradients for every Tensor that is a leaf (in nn, usually nn.Parameters). The max memory increases here because halfway through the backprop, we had to allocate more buffers for the computed gradients but could not yet free the buffers that were no longer needed. It is expected that the max memory usage is reached during the backward phase. Note as well that at the end of this phase, the allocated memory is greatly reduced again, as we have freed all the intermediate buffers. You don’t go all the way back down to the level before the forward call because the gradients are still there (and some internal buffers that some ops have, like batchnorm).
3- The extra 500M you see during the first optimizer step is exactly what you computed in question 2: the extra buffers needed by the Adam optimizer. model.zero_grad() will zero out the gradient Tensors but not free them. Since you will reuse them, it would be bad performance-wise to free and re-allocate them every time. I guess the memory allocation here is: 500M gradient Tensors, 500M Adam buffers, 300M others (I guess stats from batchnorm or a similar layer).
4- The max memory increases here again, most likely because of the allocator. It has some buffers that it uses but cannot reuse exactly the same ones, and so has to re-allocate a bit more. This can happen for the first two steps but should be stable after that. As you can see, from then on only the max and cached values change, not the memory that is actually in use. If you’re familiar with memory fragmentation, that’s what is happening here.
5- That depends on why you want to reduce memory usage.

  • If it’s to be able to fit a larger batch size or a larger net, you don’t need to do anything. Even though more memory is allocated on the device, all the cached but unused memory is still available to hold Tensors, so it should not be a problem. If you still OOM, you can use the checkpoint tool (torch.utils.checkpoint) to trade compute for memory.
  • If you want to give that memory back to the OS to use it with another program, that’s going to be a bit trickier. There is a torch.cuda.empty_cache() function that returns all cached but unused memory to the OS, but you should be very careful when using it. It will significantly slow down the next allocations (as we need to re-allocate memory from cuda) and won’t necessarily reduce the peak memory usage, as it may not decrease fragmentation.
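
The checkpoint suggestion in the first bullet can be sketched as follows. This is a minimal illustration on a hypothetical toy module (`CheckpointedMLP` is not part of the model in this thread), assuming a recent PyTorch where `torch.utils.checkpoint.checkpoint` accepts `use_reentrant=False`. The activations inside the checkpointed block are not kept after the forward pass and are recomputed during backward, which costs extra compute but lowers the memory held between forward and backward.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedMLP(nn.Module):
    """Hypothetical toy module; only `block` is wrapped in a checkpoint."""

    def __init__(self, dim=64, use_checkpoint=True):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(dim, dim), nn.Tanh(),
            nn.Linear(dim, dim), nn.Tanh(),
        )
        self.head = nn.Linear(dim, 1)
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        if self.use_checkpoint:
            # Intermediate activations of `block` are recomputed in backward
            # instead of being stored during forward.
            h = checkpoint(self.block, x, use_reentrant=False)
        else:
            h = self.block(x)
        return self.head(h)


torch.manual_seed(0)
model = CheckpointedMLP()
x = torch.randn(8, 64)

# Gradients with checkpointing
model.use_checkpoint = True
model(x).sum().backward()
grad_ckpt = model.block[0].weight.grad.clone()

# Gradients without checkpointing, same weights and input
model.zero_grad()
model.use_checkpoint = False
model(x).sum().backward()
grad_plain = model.block[0].weight.grad.clone()

print(torch.allclose(grad_ckpt, grad_plain))
```

In the seq2seq model above, a natural candidate for such a wrapper would be the per-step decoder computation, though whether that helps in practice would need to be measured.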

Hope this helps, let me know if you have further questions!
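
The point in answer 3 about zero_grad() keeping the gradient Tensors around can be checked on a small example. This is a minimal CPU sketch on a hypothetical toy layer, assuming a PyTorch version where `Module.zero_grad` accepts the `set_to_none` argument; as far as I know, PyTorch 2.0 and later default to `set_to_none=True`, which drops the gradient Tensors instead of zeroing them.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # hypothetical toy layer

# Zeroing in place: the gradient Tensors stay allocated and are filled with 0.
model(torch.randn(3, 4)).sum().backward()
model.zero_grad(set_to_none=False)
kept = all(
    p.grad is not None and bool((p.grad == 0).all())
    for p in model.parameters()
)

# Setting to None: the gradient Tensors are released until the next backward.
model(torch.randn(3, 4)).sum().backward()
model.zero_grad(set_to_none=True)
released = all(p.grad is None for p in model.parameters())

print(kept, released)
```

Releasing the gradients lowers the allocated memory between steps, but they are re-allocated during the next backward, so the peak is not necessarily lower.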

(Stas Bekman) #3

@albanD, your explanation was very helpful and timely. Thank you!

Is there anything that can be done to avoid the inefficient setup phase, which at times leads to a huge peak memory usage, without affecting normal training (i.e. without reducing bs, model size, etc.)?

For example, MNIST (28x28) w/ bs=512 and resnet50 peaks at 6GB on the first pass, and then goes down to a steady 1GB peak for subsequent epochs - 6 vs 1 is an insanely huge overhead!

Can the computational graph/gradients’ setup be done in stages and then combined to keep the peak memory usage at a smaller multiple?

(Alban D) #4

Do you actually see OOM errors due to this behavior?
The reason I ask is that the allocator in pytorch is built to be as fast as possible, not to reduce fragmentation. The “trick” used to reduce fragmentation is that when no more memory is available, all the unused blocks are freed (cudaFreed) and then a new block is allocated (cudaMalloced). The cuda driver can reduce fragmentation much better than we could, by remapping some memory.

It’s possible that if only 2GB were available, your process would still run just fine, with a peak usage of 2GB at the first iteration and 1GB afterwards.
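
The difference between memory in use and memory held by the caching allocator can be observed directly. The sketch below is illustrative only: the function name `cache_demo` is hypothetical, it assumes a PyTorch version that provides `torch.cuda.memory_reserved()` (the newer name for `memory_cached()`), and it returns None on a machine without CUDA.

```python
import torch


def cache_demo(n_bytes=64 * 1024 * 1024):
    """Allocate and free a buffer, then compare allocator stats around empty_cache()."""
    if not torch.cuda.is_available():
        return None
    buf = torch.empty(n_bytes, dtype=torch.uint8, device='cuda')
    del buf  # the block goes back to PyTorch's cache, not to the driver

    stats = {
        'allocated_after_del': torch.cuda.memory_allocated(),
        'reserved_before_empty_cache': torch.cuda.memory_reserved(),
    }
    # Hand cached, unused blocks back to the driver (slows later allocations).
    torch.cuda.empty_cache()
    stats['reserved_after_empty_cache'] = torch.cuda.memory_reserved()
    return stats


stats = cache_demo()
print(stats)
```

The reserved value is what tools such as nvidia-smi roughly reflect, while the allocated value is what live Tensors actually occupy.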

(Stas Bekman) #5

Wow, this is amazing!

I first ran it after leaving the suggested 2GB free, then kept reducing the free memory and got an even better number than you suggested (in MBs):

left free: 8000
epoch used peak
1       80 6220
2        2  914

left free: 1900
epoch used peak
1       80 1856
2        2  916

left free:  900
epoch used peak
1        8  896
2        2  806

left free:  800
hit OOM

Kernel was restarted for each test. I didn’t run the 3rd epoch, since it always gave the same numbers as the 2nd.

So, looking at these numbers, it appears that the free memory needs to be at least as large as the 2nd epoch's peak, though even that peak got smaller with less free GPU memory available.

I reduced bs from 512 to 128 and reran the tests to validate my observation:

left free: 8000
epoch used peak
1       36 3238
2        0  260

left free:  300
epoch used peak
1       46  294
2        0  226

left free: 200
hit OOM

So, it is the 2nd epoch’s memory peak that’s the required minimum!

This is great!

Thank you, @albanD!