Handling NaN loss

Hi, PyTorch gurus:

I have a training flow that can produce a NaN loss because of some inf activations. I already know this is caused by a noisy dataset, but cleaning up the dataset is hard/disallowed and the dataset/dataloader is fixed. Hence, I'm looking for a way to live with the NaN loss.

One option is to reset all the gradients to zero after backprop is over and continue with the next batch. But this requires iterating over all the grads, which looks excessive.
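A minimal sketch of that reset-after-backprop option, assuming a standard single-process loop; the helper name zero_grads_if_nonfinite and the toy model/data are hypothetical. It checks each gradient for NaN/inf and clears all of them if any is bad. Under DDP the gradients are already all-reduced when backward() returns, so every rank would see the same non-finite values and make the same skip decision.

```python
import torch
import torch.nn as nn


def zero_grads_if_nonfinite(model: nn.Module, optimizer: torch.optim.Optimizer) -> bool:
    """Hypothetical helper: clear all grads if any of them contains NaN/inf.

    Returns True if the gradients were finite (i.e. safe to call optimizer.step()).
    """
    for p in model.parameters():
        if p.grad is not None and not torch.isfinite(p.grad).all():
            optimizer.zero_grad(set_to_none=True)  # drop every gradient
            return False
    return True


model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Toy batch: the second sample is "noisy" and contains a NaN.
x = torch.tensor([[1.0, 2.0, 3.0], [float("nan"), 0.0, 0.0]])
y = torch.zeros(2, 1)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()

ok = zero_grads_if_nonfinite(model, optimizer)
if ok:
    optimizer.step()
print(ok)  # False: the NaN sample poisoned the gradients
```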

Is there a clean way to set the loss to zero so that all gradients automatically become zero too, rather than resetting them after backprop? I tried manually setting the loss to zero, but after backprop the grads are still abnormal. I could skip the backprop entirely, but that is a problem in a DDP environment: the other GPUs will still run backprop, causing a deadlock (and I don't want an additional sync).
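A small toy sketch (hypothetical tensors w and x) of why zeroing the loss does not zero the gradients: backward still multiplies the zero upstream gradient by each op's local derivative, and if a forward value is inf that derivative can be inf, giving 0 * inf = nan. The same happens with torch.where(torch.isnan(loss), 0, loss).

```python
import torch

# Toy setup: a parameter w and an "activation" x that has overflowed to inf.
w = torch.tensor([1.0], requires_grad=True)
x = torch.tensor([float("inf")])

loss = (w * x - w * x).sum()  # inf - inf -> nan

# "Zero out" the loss before calling backward.
zeroed = loss * 0.0  # nan * 0 is still nan in the forward pass
zeroed.backward()

# The upstream gradient reaching w * x is 0, but the gradient w.r.t. w is
# grad * x = 0 * inf = nan, so the parameter gradient is still nan.
print(w.grad)  # tensor([nan])
```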

Any hints or guidance?

Well, when you run optimizer.zero_grad() you are basically iterating over all the parameters anyway, so I don't see why you say it's expensive.

Anyway, just check whether the loss is NaN before doing backprop:

if not torch.isnan(loss):
    loss.backward()
    optimizer.step()

And that's all.
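A self-contained sketch of that check in a single-process training step. The model, optimizer and train_step name are hypothetical, and it uses torch.isfinite instead of torch.isnan on the assumption that an inf loss should be skipped as well. As raised in the next reply, skipping backward() on only some ranks can hang DDP, so this is only safe without DDP.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


def train_step(x: torch.Tensor, y: torch.Tensor) -> bool:
    """Run one step; skip backward/step entirely if the loss is not finite.

    Single-process only: under DDP, a rank that skips backward() while the
    others call it can hang in the gradient all-reduce.
    """
    optimizer.zero_grad(set_to_none=True)
    loss = nn.functional.mse_loss(model(x), y)
    if not torch.isfinite(loss):  # catches nan as well as +/-inf
        return False
    loss.backward()
    optimizer.step()
    return True


before = model.weight.detach().clone()
did_step_nan = train_step(torch.tensor([[float("nan"), 1.0, 1.0]]), torch.zeros(1, 1))
print(did_step_nan, torch.equal(before, model.weight))  # False True
did_step_ok = train_step(torch.ones(1, 3), torch.zeros(1, 1))
print(did_step_ok)  # True
```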

I should have been clearer about that. Resetting the grads is not mechanically expensive; what I meant is that those gradients themselves are lost and the weights don't get updated in the DDP setup. I have N GPUs with B samples each. If the loss is NaN on any one of them, then after the all-reduce all the grads have to be discarded, and none of the N*B samples contributes to convergence.

Now I'm thinking of letting the NaN loss backprop but hijacking the all-reduces with a hook. Would that work?
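One possible sketch of the hook idea that does not touch DDP's all-reduce itself (DDP has its own register_comm_hook API, which this does not use): a per-parameter Tensor.register_hook that replaces non-finite gradient entries with zero. The hook name sanitize_grad and the toy model are hypothetical. Assumption: on a leaf parameter the hook runs before the gradient is accumulated and therefore before DDP's reducer reads it, so the affected rank would contribute zeros instead of NaNs; please verify that ordering on your PyTorch version. A single NaN sample usually poisons a layer's whole gradient, so in practice this drops that rank's contribution, and since DDP divides by the world size the remaining average is scaled by about (N-1)/N.

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 1)


def sanitize_grad(grad: torch.Tensor) -> torch.Tensor:
    # Hypothetical hook: the returned tensor is used in place of `grad`.
    return torch.nan_to_num(grad, nan=0.0, posinf=0.0, neginf=0.0)


handles = [p.register_hook(sanitize_grad) for p in model.parameters()]

# Toy batch with one NaN sample: without the hooks every grad would be nan.
x = torch.tensor([[1.0, 2.0, 3.0], [float("nan"), 0.0, 0.0]])
y = torch.zeros(2, 1)
nn.functional.mse_loss(model(x), y).backward()

all_finite = all(torch.isfinite(p.grad).all() for p in model.parameters())
print(all_finite)  # True

for h in handles:
    h.remove()  # detach the hooks once they are no longer needed
```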

You can simply remove the NaNs at some point inside the model by masking the output. If your loss is elementwise, it's pretty simple to do. If your loss depends on the structure of the tensor (e.g. a matrix multiplication), then replace the NaN with the null element. For example, tensor[torch.isnan(tensor)] = 0 replaces the NaNs in place, while tensor[~torch.isnan(tensor)] keeps only the non-NaN entries.
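The two masking forms above, sketched on a plain toy tensor t (a hypothetical name). The first keeps the shape, which matters when the tensor's structure is used later (e.g. in a matmul); the second returns a shorter 1-D tensor, which is fine for an elementwise loss.

```python
import torch

t = torch.tensor([1.0, float("nan"), 3.0, float("nan")])

# Option 1: keep the shape and overwrite NaNs with a neutral value (here 0).
filled = t.clone()
filled[torch.isnan(filled)] = 0.0
print(filled)  # tensor([1., 0., 3., 0.])

# Option 2: drop the NaN entries entirely (result is 1-D and shorter).
kept = t[~torch.isnan(t)]
print(kept)  # tensor([1., 3.])
```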

In the end, most losses are an average over all the samples in the batch (effectively N*B in your case, once DDP averages the gradients across GPUs).
The official PyTorch losses have a reduction argument; passing reduction='none' returns the loss for each element of the batch instead of the average (the older reduce flag is deprecated).

At that point you can simply drop the NaN elements and do the average and the backprop manually.
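A sketch of the per-sample approach, assuming the NaNs come from non-finite values in individual input samples or targets; masked_loss and the toy model/data are hypothetical. One detail matters: if a bad sample still goes through the model, its NaN reaches the backward pass, and even with a zero weight in the average, 0 * nan = nan poisons the shared weight gradients. So this sketch also replaces the bad samples with zeros before the forward pass, so the graph never sees them. Every rank still calls backward(), which keeps DDP in sync.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 1)


def masked_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # A sample is valid only if all of its features and targets are finite.
    valid = torch.isfinite(x).all(dim=1) & torch.isfinite(y).all(dim=1)
    # Zero out invalid samples *before* the forward pass so no NaN/inf
    # enters the autograd graph (otherwise backward computes 0 * nan = nan).
    x_safe = torch.where(valid[:, None], x, torch.zeros_like(x))
    y_safe = torch.where(valid[:, None], y, torch.zeros_like(y))
    per_sample = nn.functional.mse_loss(model(x_safe), y_safe, reduction="none").mean(dim=1)
    # Average over valid samples only; clamp avoids 0/0 if the whole batch is
    # bad, in which case the loss is 0 but backward() still runs.
    return (per_sample * valid).sum() / valid.sum().clamp(min=1)


x = torch.tensor([[1.0, 2.0, 3.0], [float("nan"), 0.0, 0.0], [0.5, -1.0, 2.0]])
y = torch.tensor([[1.0], [0.0], [-1.0]])

loss = masked_loss(x, y)
loss.backward()
print(loss.item(), all(torch.isfinite(p.grad).all() for p in model.parameters()))
```

Under DDP each rank averages over its own valid count and DDP then averages across ranks, so ranks with fewer valid samples get slightly more weight per sample; that is usually an acceptable approximation.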

Using hooks seems too complicated to me.

I sort of expected that, but just to double-check: so I have to replace all the NaNs in the activations with something else, like zero?

Yep, but think about your specific case.
For example, if the NaN enters through a multiplication, then replacing it with 0 does the job. You have to think in terms of the null element (the value that makes the gradient 0).
Otherwise, just use the simple trick of removing those NaNs from the N*B loss tensor.
There is already a function in PyTorch to replace NaNs: torch.nan_to_num.
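A short sketch of torch.nan_to_num on a toy tensor t (a hypothetical name). By default it maps NaN to 0 and ±inf to the largest/smallest finite value of the dtype; the nan, posinf and neginf arguments set explicit replacement values.

```python
import torch

t = torch.tensor([1.0, float("nan"), float("inf"), float("-inf")])

# Defaults: nan -> 0, +inf -> dtype max, -inf -> dtype min (most negative).
default = torch.nan_to_num(t)
print(default)

# Explicit replacements, e.g. map every non-finite entry to 0.
zeros = torch.nan_to_num(t, nan=0.0, posinf=0.0, neginf=0.0)
print(zeros)  # tensor([1., 0., 0., 0.])
```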