RuntimeError: CUDA OOM after training on a few batches

Hi everyone,

I am trying to fine-tune a T5-based model with reinforcement learning. I use the T5 model to paraphrase a relatively short query, feed the paraphrase into my environment to obtain a reward, and then use this reward to compute my loss for training.

I have scoured the forums and have tried to incorporate best practices like deleting variables once they are out of scope and storing my losses with loss.item(). I am able to train for a couple of batches, but memory usage keeps increasing as tracked by nvidia-smi. I have even reduced my batch size to 1, but to no avail: strangely, I reach the same number of trained batches before failure as I did with a batch size of 4.
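
For context, the pattern I mean looks roughly like this (a simplified illustration, not my actual training code; compute_loss, loader and optimizer are placeholder names):

running_loss = 0.0
for batch in loader:
    loss = compute_loss(batch)      # the loss tensor keeps the autograd graph alive
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    running_loss += loss.item()     # .item() returns a Python float, so no graph is retained
    del loss                        # drop the reference as soon as it is no longer needed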

I am performing the training on an RTX 2080 Ti with 11 GB of VRAM.

Does anyone know of potential causes of such behaviour? Thank you for taking the time to read this post.

Edit (code included):
Main Training Loop

for epoch in range(epochs):
    acc_loss = 0.0
    counter = 0
    for i, batch in enumerate(train_dataloader):
        queries = batch["question"]  # pass into step if using relative reward
        contexts = batch["context"]
        answers = batch["answer"]

        del batch

        # generate reformulations of the queries with the T5 agent
        output = agent.forward(queries)
        del queries

        reformulations = agent.decode_batch_sequence(
            output.sequences.detach().cpu(), skip_special_tokens=True
        )

        # query the environment for the reward of each reformulation
        rewards = world_model.step(reformulations, contexts, answers)
        del contexts, answers, reformulations

        rewards = torch.tensor(rewards, dtype=torch.float16).to(agent.device)

        # policy loss from the generation probabilities and rewards
        gen_probs, probs = agent.probabilities(output)
        loss = agent.policy_loss(rewards, gen_probs, probs)

        agent.optim.zero_grad()
        loss.backward()
        agent.optim.step()

        loss = loss.detach().item()
        wandb.log({"policy_loss": loss})

        acc_loss += float(loss)
        del loss
        counter += 1

    epoch_loss = acc_loss / counter
    wandb.log({"epoch_loss": epoch_loss})

    # keep the checkpoint with the best average epoch loss
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        agent.transformer.save_pretrained(args.save_path)

Error:

RuntimeError: CUDA out of memory. Tried to allocate 2.00 MiB already allocated; 2.44 MiB free; 9.68 GiB reserved in total by PyTorch)

Further Updates:
After using the debug_memory code snippet found here, it seems like my tensors are not cleared after the optim.step() call, but are only freed once the loss computation for the next batch starts.
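
(For anyone curious, the snippet is essentially a loop over the garbage collector's objects that prints every live CUDA tensor; paraphrased from memory, it looks something like this:)

import gc
import torch

def debug_memory():
    # print every live CUDA tensor the garbage collector knows about,
    # with its dtype, shape and approximate size in MB
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.is_cuda:
                size_mb = obj.element_size() * obj.nelement() / 1024 ** 2
                print(type(obj).__name__, obj.dtype, tuple(obj.shape), f"{size_mb:.2f} MB")
        except Exception:
            pass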

Reply from @ptrblck:

The intermediate forward activations should already have been freed during the backward() call; however, the gradients are set to zero by optimizer.zero_grad() rather than removed. To save more memory, pass set_to_none=True to zero_grad(), which lowers the memory footprint a bit further.
Also, based on your description, the OOM seems to be raised at a specific iteration regardless of the batch size, which could indicate a larger memory requirement for that iteration. Are you using variable input shapes, and could this particular batch be larger?
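
Applied to your loop it would be a one-line change (sketch, using the names from your code):

agent.optim.zero_grad(set_to_none=True)  # grads become None instead of zero-filled tensors
loss.backward()
agent.optim.step()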

Hi @ptrblck,

Thank you for taking the time to read my post and reply! You are right that a particular batch was larger than the others, and that is where the program crashed. To resolve this, I capped the maximum length of my generated sequences.
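
Concretely, I limited both the tokenized inputs and the generation length, roughly like this (the numbers are illustrative, and agent.tokenizer is just my name for the T5 tokenizer wrapped by the agent):

# cap the input length when tokenizing the queries
inputs = agent.tokenizer(
    queries, padding=True, truncation=True, max_length=64, return_tensors="pt"
).to(agent.device)

# cap the number of generated tokens so no single batch blows up memory
output = agent.transformer.generate(
    **inputs,
    max_new_tokens=32,
    return_dict_in_generate=True,
    output_scores=True,
)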

However, when I tracked the tensors in my program with the debug_memory snippet, the intermediate tensors still seem to be kept alive until I compute the loss for the next batch, which I find strange. Is that expected?
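
My best guess is that this is just Python keeping the previous iteration's references (output, gen_probs, probs) alive until they are reassigned on the next batch, so explicitly dropping them at the end of the loop body should release that memory earlier. Something like:

# at the end of each iteration, drop the remaining references so the
# generation output and probability tensors don't survive into the next batch
del output, gen_probs, probs
torch.cuda.empty_cache()  # optional: releases cached blocks so nvidia-smi reflects the drop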

But with regard to the memory issue itself, you are entirely spot on. Thank you for your help!