Colab runs out of memory when training

I am trying to train a CNN using PyTorch in Google Colab, but after around 170 batches Colab freezes because all available RAM (12.5 GB) is used. Looking around online, it seems that other people with this issue are often unintentionally storing the computation graph, e.g. by accumulating the loss tensor itself (I sketch that pattern below, after the script). Checking my code, this doesn't seem to be the problem, so I was hoping someone could confirm that and help me figure out what is actually going wrong. I am using the following training script with a custom dataloader:

import torch
from torch.utils.data import DataLoader
from torch import optim, nn
from torchsummary import summary
from dataset import AudioDataset
from model import ConvNet
import os
import time


def train(x, y, model, optimizer, loss_function, lr_scheduler=None):
    model.train()
    y_pred = model(x)
    loss = loss_function(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if lr_scheduler:
        lr_scheduler.step()
    return loss.item()


def validate(x, y, model, loss_function):
    model.eval()
    # no gradients are needed for validation, so skip building the computation graph
    with torch.no_grad():
        y_pred = model(x)
        loss = loss_function(y_pred, y)
    return loss.item()


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    config = {
        "learning_rate": 0.02,
        "batch_size": 32,
        "n_epochs": 30,
    }

    train_loader = DataLoader(
        AudioDataset(os.path.join("audio", "train")), batch_size=config["batch_size"], shuffle=True
    )
    val_loader = DataLoader(
        AudioDataset(os.path.join("audio", "test")), batch_size=config["batch_size"], shuffle=True
    )

    model = ConvNet()
    model.to(device)  # move the model to the target device before summary/optimizer use it
    summary(model, input_size=(1, 173, 64))
    loss_fn = nn.BCELoss()
    optimizer = optim.SGD(model.parameters(), lr=config["learning_rate"])
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)

    # Training loop
    for epoch in range(config["n_epochs"]):
        print(f"Epoch {epoch+1}/{config['n_epochs']}")
        start_time = time.time()
        train_loss = 0
        train_accuracy = 0
        val_loss = 0
        val_accuracy = 0
        for index, batch in enumerate(train_loader):
            if (index + 1) % 10 == 0:
                print(f"Batch {index + 1}/{len(train_loader)}")
            x_train, y_train = batch
            x_train = x_train.to(device)
            y_train = y_train.to(device)
            train_loss += train(x_train, y_train.float(), model, optimizer, loss_fn)

        for index, batch in enumerate(val_loader):
            x_val, y_val = batch
            x_val = x_val.to(device)
            y_val = y_val.to(device)
            val_loss += validate(x_val, y_val.float(), model, loss_fn)

        train_loss /= len(train_loader)
        val_loss /= len(val_loader)
        total_time = time.time() - start_time
        print(f"{total_time} - loss: {train_loss} - val_loss: {val_loss}")
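
For clarity, the pattern I checked for (and believe I am not hitting) is accumulating the loss tensor itself instead of its Python value, which keeps every batch's graph referenced. A minimal sketch of that mistake (not my actual code) would be:

running_loss = 0
for x, y in loader:
    y_pred = model(x)
    loss = loss_function(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    running_loss += loss  # keeps the graph attached; running_loss += loss.item() would not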

The amount of memory used seems to spike suddenly and keep increasing, see this graph:

Is this the RAM or the GPU memory?
Skimming through the code I cannot find anything obviously wrong.
Could you map the seconds to epochs or iterations? What happens after 300 seconds?
Are you starting to validate the model, or does a new epoch start?

This is the memory bar in Google Colab, so I think it is RAM, but I'm not 100% sure. It happens regardless of whether I use the CPU or the GPU. The graph is entirely from the first epoch; at around 300 seconds the network has trained for about 120 batches. After 170 batches (400 seconds) all memory is used and Google Colab crashes. If it helps, I can also post the code for my custom dataloader so you can check whether that is causing the problem.

If I understand the explanation correctly, the code works fine for 120 epochs and then the memory suddenly increases until you run into an OOM?

That is almost correct: it is batches rather than epochs (my mistake). The memory stays flat for the first 120 batches of the first epoch, then suddenly starts increasing up to 100% (12.5 GB), at which point Colab crashes. My whole dataset is only around 6 GB, so I don't think it is simply being loaded into memory. Like I said earlier, I don't know whether it helps to see my custom dataloader, but if it does, let me know and I will post it.

After investigating further I found the following topic, which led me to check my dataloader, and it turns out the error is indeed caused there. Since I am working with audio data I use the torchaudio package: in my dataloader I load the audio, resample it, and create a mel spectrogram, which the model is then trained on. Singling out the individual preprocessing steps shows that the memory growth comes from the creation of the mel spectrogram (torchaudio.transforms.MelSpectrogram). Since torchaudio first calls torchaudio.transforms.Spectrogram and then creates the mel bins, I also tried using only the spectrogram transform; that ran normally for longer but eventually gave the same error. See the following minimal working example:

import torch
import torchaudio

for epoch in range(10):
  print(f"Epoch {epoch+1}/{30}")
  for example in range(20000):
    if (example + 1) % 100 == 0:
      print(f"Example {example + 1}")
    x = torch.rand(1, 88200)
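    # note: a new MelSpectrogram transform object is constructed on every iteration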
    mel_spec = torchaudio.transforms.MelSpectrogram()(x)

I am not sure if/how I can fix this or if the fix from the aforementioned topic can be applied to my case.
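
One thing I have been wondering about (though I have not verified whether it actually behaves differently) is building the transform a single time and reusing it, instead of constructing a new MelSpectrogram object for every example, roughly like this:

import torch
import torchaudio

mel_transform = torchaudio.transforms.MelSpectrogram()  # created once, outside the loops

for epoch in range(10):
  print(f"Epoch {epoch+1}/{30}")
  for example in range(20000):
    if (example + 1) % 100 == 0:
      print(f"Example {example + 1}")
    x = torch.rand(1, 88200)
    mel_spec = mel_transform(x)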

@ptrblck Do you happen to have any idea what causes the problem or how I can circumvent it?

After checking the memory usage after each mel spectrogram transform, it seems that every example adds 1-2 MB to the total RAM used (for MelSpectrogram; Spectrogram seems to use around half of that), and I still haven't got a clue why. Strangely, using the functional transformation (torchaudio.functional.spectrogram, which the Spectrogram class calls internally for the actual computation) does not add to the amount of RAM used, which makes it even weirder to me. If anyone has an idea, I'd really appreciate any thoughts.
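
For reference, by "using the functional transformation" I mean calling it directly along these lines (the parameter values here are just my understanding of the Spectrogram transform's defaults, so treat this as a rough sketch; the exact argument names may differ between torchaudio versions):

import torch
import torchaudio.functional as F_audio

x = torch.rand(1, 88200)
n_fft = 400                      # roughly the transform's default FFT size
window = torch.hann_window(n_fft)
spec = F_audio.spectrogram(
  x,
  pad=0,
  window=window,
  n_fft=n_fft,
  hop_length=n_fft // 2,
  win_length=n_fft,
  power=2.0,
  normalized=False,
)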

Same issue is happening for me.