GPU out-of-memory (OOM) error when I skip a batch

Hello, I am training a CLIP (ViT) model.

I would like to skip a mini-batch when I find NaNs in the output of the model.

Here is a simplified pseudocode example:

import torch

for data in dataloaders:
    output = model(data[0])
    # skip this mini-batch if the output contains any NaN
    if torch.isnan(output).any():
        continue
    loss = loss_fn(output, data[1])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

However, when the output contains NaNs and I skip the mini-batch, I get a GPU out-of-memory error.
I suspect that something is not properly freed if I skip the mini-batch.
Could anyone explain what is going on?
Thanks!

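Yes, I think you are right. Since backward() is not called when the batch is skipped, the computation graph and the forward activations it saved stay alive as long as output still references them.
That reference is only replaced by output = model(data[0]) in the next iteration, i.e. after the next forward pass has already allocated its own activations, so two sets of activations are held in memory at the same time.
You could try to del output before the continue and see if this frees the intermediates.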
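A minimal sketch of the loop with that fix applied, assuming a toy nn.Linear model, SGD, MSE loss and synthetic batches in place of the CLIP model and dataloader. The helper name train_skipping_nan_batches and the deliberately injected NaN are hypothetical, for illustration only. It runs on CPU, so it shows the control flow rather than the memory savings themselves.

```python
import torch
import torch.nn as nn


def train_skipping_nan_batches(model, batches, loss_fn, optimizer):
    """Hypothetical helper: train on (inputs, targets) pairs and skip any
    batch whose output contains NaNs. Returns (steps_taken, batches_skipped)."""
    steps, skipped = 0, 0
    for inputs, targets in batches:
        output = model(inputs)
        if torch.isnan(output).any():
            # backward() will not run for this batch, so drop the only
            # reference to output (and with it the autograd graph and its
            # saved activations) before the next forward pass allocates more.
            del output
            skipped += 1
            continue
        loss = loss_fn(output, targets)
        optimizer.zero_grad()
        loss.backward()  # frees the graph's saved tensors after use
        optimizer.step()
        steps += 1
    return steps, skipped


torch.manual_seed(0)
model = nn.Linear(4, 2)  # toy stand-in for the CLIP model (assumption)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

batches = []
for i in range(5):
    x = torch.randn(8, 4)
    if i == 2:
        x[0, 0] = float("nan")  # hypothetical corrupted batch to trigger the skip
    batches.append((x, torch.randn(8, 2)))

steps, skipped = train_skipping_nan_batches(model, batches, loss_fn, optimizer)
print(steps, skipped)  # 4 1
```

Note that del only releases the tensors back to PyTorch's caching allocator, so tools like nvidia-smi may still report the memory as reserved even though it can be reused by the next batch.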


Thanks!
The solution was indeed really simple!
Just adding del output solved everything!