GPU memory OOM error when I skip a batch

Hello, I am training a CLIP (ViT) model.

I would like to skip a mini-batch when I find NaNs in the output of the model.

Here is a simplified pseudocode example:

for data in dataloader:
    output = model(data[0])
    if torch.isnan(output).any():
        continue  # skip this mini-batch
    loss = loss_fn(output, data[1])


However, when the output contains NaNs and I skip the mini-batch, I run into a GPU out-of-memory error.
I suspect that something is not properly freed when I skip the mini-batch.
Could I get any advice?

Yes, I think you are right. Since you are not calling backward() when the batch is skipped, the forward activations are kept alive until the reference is deleted.
You could try del output in this case and see if this frees the intermediates.
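A minimal sketch of what this could look like, assuming a standard PyTorch training step (the train_step helper and its arguments are hypothetical names for illustration):

```python
import torch

def train_step(model, loss_fn, optimizer, inputs, targets):
    """Run one step; skip the batch (and free its graph) if the output has NaNs."""
    output = model(inputs)
    if torch.isnan(output).any():
        # Deleting the output drops the last reference to the autograd graph,
        # so the forward activations can be freed even though backward() is
        # never called for this batch.
        del output
        return None
    loss = loss_fn(output, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

Without the del, the output tensor (and the whole graph of intermediate activations behind it) stays alive until the variable is reassigned on the next iteration, which can be enough to push a large model over the memory limit.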


The solution was really simple indeed!
Just del output solved everything!