CUDA memory leak while training

Hi,

I ran into a problem with a CUDA memory leak. I’m training on a single GPU with 16 GB of memory, and I keep running out of memory after a number of steps — around 500 out of 4000.
My dataset is quite big, and it crashes during the first epoch.
I noticed that memory usage is growing steadily, but I can’t figure out why.

At first I wasn’t explicitly clearing the CUDA cache and thought that doing so might fix the problem, but even with torch.cuda.empty_cache() I keep seeing the same issue:

I’m using PyTorch Lightning, but I don’t think it is the cause; I suspect my model instead.
Can anyone spot the issue, or point me in the right direction on how to find a memory leak?

Below is my full model:

from collections import OrderedDict

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

from preprocessing import GreedyDecoder, cer, wer


class SpeechRecognitionModel(pl.LightningModule):
    """CNN + bidirectional-GRU speech recognition model trained with CTC loss.

    Architecture (see ``forward``): a strided Conv2d stem, ``n_cnn_layers``
    ``ResidualCNN`` blocks, a linear projection to ``rnn_dim``,
    ``n_rnn_layers`` ``BidirectionalGRU`` layers and a two-layer classifier
    that emits ``n_class`` scores per time step.

    Args:
        hparams: dict read for the keys ``n_cnn_layers``, ``n_rnn_layers``,
            ``rnn_dim``, ``n_class``, ``n_feats``, ``stride``, ``dropout``
            (in ``__init__``) and ``learning_rate``, ``epochs``
            (in ``configure_optimizers``).
    """

    def __init__(self, hparams: dict):
        super(SpeechRecognitionModel, self).__init__()
        self.hparams = hparams
        # Label 37 is used as the CTC blank. With n_class == 38 it is the last
        # class; presumably it must match the blank used by GreedyDecoder - TODO confirm.
        self.criterion = nn.CTCLoss(blank=37)

        n_cnn_layers = hparams['n_cnn_layers']
        n_rnn_layers = hparams['n_rnn_layers']
        rnn_dim = hparams['rnn_dim']
        n_class = hparams['n_class']
        n_feats = hparams['n_feats']
        stride = hparams['stride']
        dropout = hparams['dropout']

        # Size of the feature axis after the stem conv (kernel 3, padding 1):
        # floor((n_feats - 1) / stride) + 1. For stride 2 and an even n_feats this
        # equals n_feats // 2 (the previously hard-coded value), but it also stays
        # correct for other strides / odd n_feats. The ResidualCNN layer norms and
        # the fully connected layer below both depend on this size being exact.
        n_feats = (n_feats - 1) // stride + 1
        # Stem conv extracting hierarchical features from the 1-channel input.
        self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2)

        # n residual CNN blocks with 32 channels each
        self.rescnn_layers = nn.Sequential(*[
            ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats)
            for _ in range(n_cnn_layers)
        ])
        self.fully_connected = nn.Linear(n_feats * 32, rnn_dim)
        # forward() hands every GRU layer a (batch, time, feature) tensor, so all of
        # them need batch_first=True. The former ``batch_first=i == 0`` made layers
        # 2..n interpret the batch axis as the sequence axis, i.e. they recurred
        # across different utterances of the batch instead of across time.
        self.birnn_layers = nn.Sequential(*[
            BidirectionalGRU(rnn_dim=rnn_dim if i == 0 else rnn_dim * 2,
                             hidden_size=rnn_dim, dropout=dropout, batch_first=True)
            for i in range(n_rnn_layers)
        ])
        self.classifier = nn.Sequential(
            nn.Linear(rnn_dim * 2, rnn_dim),  # birnn returns rnn_dim*2
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(rnn_dim, n_class)
        )

    def forward(self, x):
        """Map a spectrogram batch to per-time-step class scores.

        Args:
            x: tensor of shape (batch, 1, n_feats, time); the stem conv takes a
                single input channel.

        Returns:
            Unnormalised scores of shape (batch, time', n_class), where time' is
            the time axis after the strided stem conv.
        """
        x = self.cnn(x)
        x = self.rescnn_layers(x)  # (batch, channel, feature, time)
        sizes = x.size()
        x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # (batch, feature, time)
        x = x.transpose(1, 2)  # (batch, time, feature)
        x = self.fully_connected(x)
        x = self.birnn_layers(x)
        x = self.classifier(x)
        return x

    def configure_optimizers(self):
        """Create AdamW plus a per-step OneCycleLR schedule.

        Assumes ``self.train_dataloader()`` returns a torch ``DataLoader``, whose
        ``len()`` is already the number of batches per epoch - TODO confirm.
        """
        optimizer = optim.AdamW(self.parameters(), self.hparams['learning_rate'])
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.hparams['learning_rate'],
            # len(DataLoader) counts batches; dividing by batch_size again would
            # make the schedule far too short.
            steps_per_epoch=len(self.train_dataloader()),
            epochs=self.hparams['epochs'],
            anneal_strategy='linear'
        )

        # OneCycleLR is meant to be stepped after every optimizer step, so ask
        # Lightning to step it per batch rather than per epoch.
        return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]

    def training_step(self, batch, batch_nb):
        """Compute the CTC loss for one training batch.

        ``batch`` is unpacked as (spectrograms, labels, input_lengths,
        label_lengths); the lengths are passed straight to ``nn.CTCLoss``.
        """
        spectrograms, labels, input_lengths, label_lengths = batch

        output = self(spectrograms)  # (batch, time, n_class)
        output = F.log_softmax(output, dim=2)
        output = output.transpose(0, 1)  # (time, batch, n_class), as CTCLoss expects

        loss = self.criterion(output, labels, input_lengths, label_lengths)

        # 'loss' must stay attached to the graph for backward, but the logged copy
        # is detached so a logger that keeps it cannot pin the graph in memory.
        # No torch.cuda.empty_cache() here: it cannot release tensors that are
        # still referenced and only forces re-allocation on the next step.
        comet_logs = {'training_loss': loss.detach()}
        return {'loss': loss, 'log': comet_logs}

    def validation_step(self, batch, batch_nb):
        """Compute CTC loss and mean greedy-decoding CER/WER for one batch.

        Returns:
            OrderedDict with 'val_loss' (detached scalar tensor) and
            'avg_cer' / 'avg_wer' (scalar tensors averaged over the batch).
        """
        spectrograms, labels, input_lengths, label_lengths = batch
        output = self(spectrograms)  # (batch, time, n_class)
        output = F.log_softmax(output, dim=2)
        output = output.transpose(0, 1)  # (time, batch, n_class)

        loss_val = self.criterion(output, labels, input_lengths, label_lengths)
        decoded_preds, decoded_targets = GreedyDecoder(output.transpose(0, 1), labels, label_lengths)
        val_cer, val_wer = [], []
        for target, pred in zip(decoded_targets, decoded_preds):
            val_cer.append(cer(target, pred))
            val_wer.append(wer(target, pred))

        avg_cer = sum(val_cer) / len(val_cer)
        avg_wer = sum(val_wer) / len(val_wer)

        # These dicts are aggregated in validation_epoch_end; the loss is detached
        # so the stored value never references an autograd graph.
        return OrderedDict({
            'val_loss': loss_val.detach(),
            "avg_cer": torch.tensor(avg_cer),
            "avg_wer": torch.tensor(avg_wer),
        })

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation metrics produced by validation_step."""
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()

        val_cer = [x['avg_cer'] for x in outputs]
        val_wer = [x['avg_wer'] for x in outputs]
        avg_cer = sum(val_cer) / len(val_cer)
        avg_wer = sum(val_wer) / len(val_wer)

        comet_logs = OrderedDict({
            'val_loss': avg_loss,
            "avg_cer": avg_cer,
            "avg_wer": avg_wer}

        )
        return {'val_loss': avg_loss, 'log': comet_logs}


class CNNLayerNorm(nn.Module):
    """Layer normalisation over the feature axis of a CNN activation.

    ``nn.LayerNorm`` normalises over the trailing dimension, so the input
    (batch, channel, feature, time) is transposed to put ``feature`` last,
    normalised, and transposed back.

    Args:
        n_feats: size of the feature axis (dimension 2 of the input).
    """

    def __init__(self, n_feats):
        super().__init__()
        self.layer_norm = nn.LayerNorm(n_feats)

    def forward(self, x):
        # (batch, channel, feature, time) -> (batch, channel, time, feature)
        feature_last = torch.transpose(x, 2, 3).contiguous()
        normalised = self.layer_norm(feature_last)
        # back to (batch, channel, feature, time), materialised contiguously
        return torch.transpose(normalised, 2, 3).contiguous()


class ResidualCNN(nn.Module):
    """Pre-activation residual block built from two convolutions.

    Inspired by https://arxiv.org/pdf/1603.05027.pdf, but normalised with
    ``CNNLayerNorm`` (layer norm over the feature axis) instead of batch norm.
    Each of the two stages applies norm -> GELU -> dropout -> conv, and the
    block input is added to the result.

    The skip connection is an identity, so the conv output must match the
    input shape; presumably this requires in_channels == out_channels and
    stride == 1 - TODO confirm for other configurations.
    """

    def __init__(self, in_channels, out_channels, kernel, stride, dropout, n_feats):
        super().__init__()
        pad = kernel // 2
        self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=pad)
        self.cnn2 = nn.Conv2d(out_channels, out_channels, kernel, stride, padding=pad)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.layer_norm1 = CNNLayerNorm(n_feats)
        self.layer_norm2 = CNNLayerNorm(n_feats)

    def forward(self, x):
        # x: (batch, channel, feature, time)
        skip = x
        out = self.cnn1(self.dropout1(F.gelu(self.layer_norm1(x))))
        out = self.cnn2(self.dropout2(F.gelu(self.layer_norm2(out))))
        out += skip
        return out  # (batch, channel, feature, time)


class BidirectionalGRU(nn.Module):
    """LayerNorm -> GELU -> single-layer bidirectional GRU -> dropout.

    Args:
        rnn_dim: size of the input's last (feature) axis; used both as the
            LayerNorm size and as the GRU input size.
        hidden_size: GRU hidden size per direction. The output feature size is
            ``2 * hidden_size`` because the GRU is bidirectional.
        dropout: dropout probability applied to the GRU output.
        batch_first: forwarded to ``nn.GRU``; True means the input is
            (batch, seq, feature), False means (seq, batch, feature).
    """

    def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
        super().__init__()
        self.BiGRU = nn.GRU(
            input_size=rnn_dim,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=batch_first,
            bidirectional=True,
        )
        self.layer_norm = nn.LayerNorm(rnn_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        activated = F.gelu(self.layer_norm(x))
        # nn.GRU returns (output, h_n); only the per-step output is used.
        seq_out, _final_hidden = self.BiGRU(activated)
        return self.dropout(seq_out)
2 Likes

If you store tensors, which are attached to the computation graph, you will also store the whole graph.
Usually this increases the memory so that you would hit an OOM issue after a while.

I don’t know how these values are used in your code, but here are some potential candidates:

# in training_step
comet_logs = {'training_loss': loss}
return {'loss': loss, 'log': comet_logs}

# in validation step
output = OrderedDict({
    'val_loss': loss_val,
    "avg_cer": torch.tensor(avg_cer),
    "avg_wer": torch.tensor(avg_wer),
})

# in validation_epoch_end
comet_logs = OrderedDict({
    'val_loss': avg_loss,
    "avg_cer": avg_cer,
    "avg_wer": avg_wer}
)

If you are storing these dicts outside of your training and validation methods, this might explain the OOM. If you don’t need to call backward on these tensors anymore, you could call .detach() on all of them to avoid this issue.

Also, you could wrap your validation loop in a with torch.no_grad() block to save further memory.

Thanks for taking a look into the issue.

I’m not storing these tensors manually, and I can’t detach the loss in the training step because PyTorch Lightning needs it for backward. I tried detaching it and using loss.item() in comet_logs = {}; nothing has worked so far.

I’ll look around in PyTorch Lightning code for details what they are doing with these dicts.
But any ideas would be helpful.

Ah OK, thanks for the information.
I assume the validation loss could at least be detached.

Could you also post another figure showing the increase in memory usage?
The first posted image (with torch.cuda.empty_cache()) doesn’t show a clear increase.

Here’s the memory usage without torch.cuda.empty_cache()

It doesn’t say much.
I also set up memory profiling found in this topic How to debug causes of GPU memory leaks? - #18 by Jean_Da_Rolt

Memory usage fluctuates a bit but stays around 12800 MB after step ~220.

What I noticed is that it ALWAYS crashes on step 507/3957. Not sure if that indicates anything at all.

RuntimeError: CUDA out of memory. Tried to allocate 1.14 GiB (GPU 0; 14.76 GiB total capacity; 12.09 GiB already allocated; 483.44 MiB free; 13.48 GiB reserved in total by PyTorch)

I also tried to remove comet logger and all logging, but result was the same.
I could try to remove PyTorch Lightning and do training manually, I don’t think that it would help.

It seems that the memory usage is pretty constant for some iterations. The initial peak usage might be caused by creating some temporary objects, which might have been deleted later.

Are you using some kind of variable shape in your model, which could increase the memory after 507 steps or is this particular iteration somehow “special”?

No. Whole model is what you see in the first post.
Preprocessing is identical for all steps: load audio and transcript, convert audio to MelSpectrogram, convert transcript to tensor.
Not even sure where to look next :slight_smile:

Could you post a small script to reproduce this issue (or to make the code executable)?
I.e. we would need the model arguments to initialize the model as well as the shapes and types of all input tensors.

I can provide all scripts or access to GH repo, but won’t you need data to replicate the problem?

I could create random tensors, if you provide the shapes.
This would at least allow to debug the current model and training code.

spectrograms, labels, input_lengths, label_lengths = batch

spectrograms: torch.Size([8, 1, 128, 644])
labels: torch.Size([8, 108])
input_lengths: 8
label_lengths: 8
    hparams = {
        "n_cnn_layers": 3,
        "n_rnn_layers": 5,
        "rnn_dim": 512,
        "n_class": 38,
        "n_feats": 128,
        "stride": 2,
        "dropout": 0.1,
        "learning_rate": 0.002,
        "batch_size": 8,
        "epochs": 3
    }

Thanks for the shapes.
Unfortunately, the imports from preprocessing are not defined, so that I cannot run it.
Could you post these classes/methods as well as the criterion, please?

I added PyTorch Lightning implementation based on this Colab Notebook.

I also tested by removing all PyTorch Lightning implementation and used only PyTorch - memory leak was still present. Although on my dataset it managed to get through more than 85% of the first epoch.

Train Epoch: 1 [43200/50639 (85%)]	Loss: 2.600413
Traceback (most recent call last):
  File "entrypoint.py", line 169, in <module>
    experiment=experiment
  File "entrypoint.py", line 148, in main
    train(model, device, train_loader, criterion, optimizer, scheduler, epoch, iter_meter, experiment)
  File "entrypoint.py", line 46, in train
    loss.backward()
  File "/home/project/.venv/lib/python3.6/site-packages/comet_ml/monkey_patching.py", line 292, in wrapper
    return_value = original(*args, **kwargs)
  File "/home/project/.venv/lib/python3.6/site-packages/torch/tensor.py", line 198, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/project/.venv/lib/python3.6/site-packages/torch/autograd/__init__.py", line 100, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: CUDA out of memory. Tried to allocate 1.03 GiB (GPU 0; 14.76 GiB total capacity; 11.60 GiB already allocated; 709.44 MiB free; 13.26 GiB reserved in total by PyTorch) (malloc at /pytorch/c10/cuda/CUDACachingAllocator.cpp:289)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x46 (0x7f4139f78536 in /home/project/.venv/lib/python3.6/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x1cf1e (0x7f413a1c1f1e in /home/project/.venv/lib/python3.6/site-packages/torch/lib/libc10_cuda.so)
frame #2: <unknown function> + 0x1df9e (0x7f413a1c2f9e in /home/project/.venv/lib/python3.6/site-packages/torch/lib/libc10_cuda.so)
frame #3: at::native::empty_cuda(c10::ArrayRef<long>, c10::TensorOptions const&, c10::optional<c10::MemoryFormat>) + 0x135 (0x7f413cd56535 in /home/project/.venv/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0xf7a66b (0x7f413b34e66b in /home/project/.venv/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so)
frame #5: <unknown function> + 0xfc3f57 (0x7f413b397f57 in /home/project/.venv/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so)
frame #6: <unknown function> + 0x1075389 (0x7f41778d2389 in /home/project/.venv/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #7: <unknown function> + 0x10756c7 (0x7f41778d26c7 in /home/project/.venv/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #8: <unknown function> + 0xf30f5e (0x7f413b304f5e in /home/project/.venv/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so)
frame #9: <unknown function> + 0xf3ccc7 (0x7f413b310cc7 in /home/project/.venv/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so)
frame #10: at::native::_cudnn_rnn_backward(at::Tensor const&, c10::ArrayRef<at::Tensor>, long, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, long, long, long, bool, double, bool, bool, c10::ArrayRef<long>, at::Tensor const&, at::Tensor const&, std::array<bool, 4ul>) + 0x1a9 (0x7f413b313319 in /home/project/.venv/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so)
frame #11: <unknown function> + 0xfc20ad (0x7f413b3960ad in /home/project/.venv/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so)
frame #12: <unknown function> + 0xfc3843 (0x7f413b397843 in /home/project/.venv/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so)
frame #13: <unknown function> + 0x2b08450 (0x7f4179365450 in /home/project/.venv/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #14: <unknown function> + 0x2b7b8a3 (0x7f41793d88a3 in /home/project/.venv/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #15: torch::autograd::generated::CudnnRnnBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) + 0x708 (0x7f4179119d28 in /home/project/.venv/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #16: <unknown function> + 0x2d89c05 (0x7f41795e6c05 in /home/project/.venv/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #17: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&) + 0x16f3 (0x7f41795e3f03 in /home/project/.venv/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #18: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&, bool) + 0x3d2 (0x7f41795e4ce2 in /home/project/.venv/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #19: torch::autograd::Engine::thread_init(int) + 0x39 (0x7f41795dd359 in /home/project/.venv/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #20: torch::autograd::python::PythonEngine::thread_init(int) + 0x38 (0x7f4185d1c4d8 in /home/project/.venv/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #21: <unknown function> + 0xbd6df (0x7f4199fd66df in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #22: <unknown function> + 0x76db (0x7f41a08656db in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #23: clone + 0x3f (0x7f41a0b9e88f in /lib/x86_64-linux-gnu/libc.so.6)

Tried changing various parameters:

  • reduced n_rnn_layers to 4, 3, 2
  • reduced rnn_dim to 256

it seems that nothing helped so far.

I’ve used the same model to train a network and I haven’t experienced OOMs you’ve mentioned. However, I’ve used my private dataset + custom environment + Ignite.

This would probably imply that either some specific environment setup (PyTorch, CUDA versions) is causing the issue, or the original script has a leak somewhere that we can’t see yet.

1 Like

Took a closer look at my dataset. It had audio files which produced very long inputs which were the cause of OOM.

Lessons learned:

  • Check your data. Even if dataset description says, that it’s clean, prepared for training.
  • Audio sample rate matters — it should be 8000 or 16000 Hz (in my case). Mine varied from 22.4 kHz to 44.1 kHz.

You’ve probably already noticed, but there is a “Resample” transform to fix that.

1 Like

It was faster with a script and sox to convert them all at once. Resample would add additional overhead each time. Maybe it could be used as an augmentation technique?

It’s very unlikely that there is any benefit in doing that. But if the network manages to consume different sample rates with a few extra convolution layers, that would maybe improve its generalization capabilities?

But in my experience, it’s hard to make this network converge as it is.

With 8000 sample rate it failed miserably almost instantly.
With 16000 it trains. Unfortunately training is very slow as I have only 1 GPU at the moment (GCP refuses to increase the quota).