PyTorch: how to fix a GPU memory leak after exactly 10 minibatches

So I've had this problem for a few days and it's driving me crazy. I have a seq2seq model (specifically a Listen, Attend and Spell model from Google) that I'm training, and GPU memory usage always blows up when it reaches minibatch number 10. I have tried shuffling the dataloader to make sure the problem wasn't related to the specific data being retrieved on the 10th iteration, but the error kept happening at that same iteration.

My train loop looks as follows:

for epoch in range(start_epoch, epochs):
    epoch_step = 0
    train_loss = []
    train_ler = []
    batch_loss = 0
    for i, data in enumerate(train_loader):
        print(
            f"Current Epoch: {epoch} Loss {np.round(batch_loss, 3)} | Epoch step: {epoch_step}/{len(train_loader)}",
            end="\r",
            flush=True,
        )
        # Anneal the teacher forcing rate
        tf_rate = tf_rate_upperbound - (tf_rate_upperbound - tf_rate_lowerbound) * min(
            (float(global_step) / tf_decay_step), 1
        )
        with torch.no_grad():
            inputs = data[1]["inputs"].cuda()
            labels = data[2]["targets"].cuda()

        # minibatch execution
        batch_loss, batch_ler = batch_iterator(
            batch_data=inputs,
            batch_label=labels,
            las_model=las,
            optimizer=optimizer,
            tf_rate=tf_rate,
            is_training=True,
            max_label_len=params["data"]["vocab_size"],
            label_smoothing=params["training"]["label_smoothing"],
        )
        del inputs
        del labels
        torch.cuda.empty_cache()

        train_loss.append(batch_loss)
        train_ler.extend(batch_ler)

        global_step += 1
        epoch_step += 1

Here batch_iterator trains the network on the minibatch; it looks as follows:

def batch_iterator(
    batch_data,
    batch_label,
    las_model,
    optimizer,
    tf_rate,
    is_training,
    max_label_len,
    label_smoothing,
    use_gpu=True,
):
    max_label_len = min(batch_label.size(1), max_label_len)
    criterion = nn.NLLLoss(ignore_index=0).cuda()
    optimizer.zero_grad()
    raw_pred_seq, _ = las_model(
        batch_data=batch_data,
        batch_label=batch_label,
        teacher_force_rate=tf_rate,
        is_training=is_training,
    )
    pred_y = (
        torch.cat([torch.unsqueeze(each_y, 1) for each_y in raw_pred_seq], 1)[:, :max_label_len, :]
    ).contiguous()

    if label_smoothing == 0.0 or not is_training:
        pred_y = pred_y.permute(0, 2, 1)  # NLLLoss expects (batch, classes, seq_len)
        true_y = torch.max(batch_label, dim=2)[1][:, :max_label_len].contiguous()

        loss = criterion(pred_y, true_y)
        # tensors -> numpy before sending into the LER calculator
        batch_ler = LetterErrorRate(
            torch.max(pred_y.permute(0, 2, 1), dim=2)[1].cpu().numpy(),
            true_y.cpu().data.numpy(),
        )

    else:
        true_y = batch_label[:, :max_label_len, :].contiguous()
        true_y = true_y.type(torch.cuda.FloatTensor) if use_gpu else true_y.type(torch.FloatTensor)
        loss = label_smoothing_loss(pred_y, true_y, label_smoothing=label_smoothing)
        batch_ler = LetterErrorRate(
            torch.max(pred_y, dim=2)[1].cpu().numpy(),
            torch.max(true_y, dim=2)[1].cpu().data.numpy(),
        )

    if is_training:
        loss.backward()
        optimizer.step()

    batch_loss = loss.cpu().data.numpy()

    return batch_loss, batch_ler

The full code of the project is in this GitHub repository, https://github.com/jiwidi/las-pytorch, if you want to check it:

While running on a V100 (16 GB of GPU RAM), memory usage stays constant at around 8 GB and then spikes to 14 GB on the 10th iteration. Since the error keeps happening at exactly this iteration no matter how I shuffle the dataloader, I believe the problem is in my code.

I also tried putting print statements before each batch_iterator call to print the size of both inputs and labels; their sizes were equal to those in the previous minibatches. Output here:

For minibatch 0 inputs has size 2.048mb and labels has size 0.706944mb
For minibatch 1 inputs has size 2.08896mb and labels has size 0.706944mb
For minibatch 2 inputs has size 2.00704mb and labels has size 0.706944mb
For minibatch 3 inputs has size 2.12992mb and labels has size 0.706944mb
For minibatch 4 inputs has size 2.17088mb and labels has size 0.706944mb
For minibatch 5 inputs has size 2.048mb and labels has size 0.706944mb
For minibatch 6 inputs has size 2.048mb and labels has size 0.706944mb
For minibatch 7 inputs has size 2.048mb and labels has size 0.706944mb
For minibatch 8 inputs has size 1.92512mb and labels has size 0.706944mb
For minibatch 9 inputs has size 2.048mb and labels has size 0.706944mb
For minibatch 10 inputs has size 1.96608mb and labels has size 0.706944mb
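
For reference, a print along these lines produces the numbers above (a sketch; the exact statement isn't in the snippet, and tensor_size_mb is a hypothetical helper):

def tensor_size_mb(t):
    # bytes occupied by the tensor's elements, reported in MB (1e6 bytes)
    return t.element_size() * t.nelement() / 1e6

print(
    f"For minibatch {i} inputs has size {tensor_size_mb(inputs)}mb"
    f" and labels has size {tensor_size_mb(labels)}mb"
)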

The actual error PyTorch gives me:

Traceback (most recent call last):
  File "train.py", line 138, in <module>
    label_smoothing=params["training"]["label_smoothing"],
  File "/home/fhjaime966/las-pytorch/solver/solver.py", line 63, in batch_iterator
    is_training=is_training,
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/fhjaime966/las-pytorch/model/las_model.py", line 34, in forward
    listener_feature, ground_truth=batch_label, teacher_force_rate=teacher_force_rate
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/fhjaime966/las-pytorch/model/las_model.py", line 212, in forward
    rnn_input, hidden_state, listener_feature
  File "/home/fhjaime966/las-pytorch/model/las_model.py", line 180, in forward_step
    attention_score, context = self.attention(rnn_output, listener_feature)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/fhjaime966/las-pytorch/model/las_model.py", line 296, in forward
    * attention_score[0].unsqueeze(2).repeat(1, 1, listener_feature.size(2)),
RuntimeError: CUDA out of memory. Tried to allocate 12.00 MiB (GPU 0; 15.75 GiB total capacity; 14.69 GiB already allocated; 6.88 MiB free; 14.69 GiB reserved in total by PyTorch)

I know this is quite a big problem to debug, since you would need to check the code in the repository https://github.com/jiwidi/las-pytorch to understand what is running, so any help or tips for debugging GPU memory usage in PyTorch are greatly appreciated.

Could you print(torch.cuda.memory_summary()) in each iteration and check if the memory is growing?
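
Something along these lines inside the loop would show it (a rough sketch; adapt the variable names to your loop):

# after batch_iterator / optimizer.step() in each iteration
allocated = torch.cuda.memory_allocated() / 1e9   # memory held by live tensors, in GB
reserved = torch.cuda.memory_reserved() / 1e9     # memory cached by the allocator, in GB
print(f"iter {i}: allocated {allocated:.2f} GB | reserved {reserved:.2f} GB")
print(torch.cuda.memory_summary())                # full, fairly verbose breakdown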

Hi!

Sorry for the late reply. My error was that I was using too high a max_sample_length for my attention mechanism, and that was causing a spike on the 10th iteration (decided by a random int, but since I had set a seed it always happened on the 10th) that allocated far more memory than was really needed. Once I fixed that, it works as expected now.
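
For anyone hitting the same thing, the shape of the fix is roughly the following (an illustrative sketch with dummy shapes and hypothetical names, not the actual repository code): cap the number of decoder/attention steps at the label length of the current batch instead of a large fixed max_sample_length.

import torch

max_sample_length = 300                    # the too-high fixed cap
batch_label = torch.zeros(8, 120, 32)      # (batch, label_len, vocab) dummy targets

# decode at most as many steps as the targets actually have
max_step = min(max_sample_length, batch_label.size(1))

outputs = []
for step in range(max_step):
    # stand-in for one attend-and-spell decoder step producing (batch, vocab) scores
    outputs.append(torch.zeros(8, 32))
pred_y = torch.stack(outputs, dim=1)       # (batch, max_step, vocab)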