Gradient becomes NaN, with a seemingly random RuntimeError: Function 'ExpBackward0' returned nan values in its 0th output

I am training a classification model, and after some time the loss values become NaN. I used torch.autograd.set_detect_anomaly(True) to trace the error, and I get this as output:

RuntimeError: Function 'ExpBackward0' returned nan values in its 0th output.
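
For reference, anomaly detection is enabled like this before the training loop (a minimal sketch, not the full script):

import torch

# Enable anomaly detection: backward() will then report the forward
# operation that produced the non-finite values (here ExpBackward0).
torch.autograd.set_detect_anomaly(True)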

I then added the following checks on the outputs, the loss, and the gradients in my training loop to find out where the problem first arises:

batch_predictions = model(batch_inputs)
batch_loss = loss(batch_predictions, batch_targets)
# Check that the forward pass and the loss are finite.
if torch.isfinite(batch_predictions).all() and torch.isfinite(batch_loss):
    print("outputs and loss ok")
else:
    print("outputs and loss not ok")
batch_loss.backward()
# Check the freshly computed gradients for NaNs.
grads = [p.grad for p in model.parameters() if p.requires_grad and p.grad is not None]
for grads_ in grads:
    if torch.isnan(grads_).any():
        print("grads nan")
    else:
        print("grads ok")
optimizer.step()

After some time I get grads nan as output, but the immediately preceding outputs and loss ok is also printed. This means that the outputs are ok and the loss is ok, but the gradient calculation in batch_loss.backward() produces NaN gradients. I have tried changing the optimizer and reducing the learning rate, but nothing works. I am not sure why this is happening or how to probe further and correct it. Thanks in advance for any help.

Hi,
Non-NaN losses together with NaN gradients are mostly the result of some undefined mathematical operation in the backward pass, like 0⁰, dividing by 0, and so on.
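
As a minimal illustration (a toy case, not your model), a loss can be perfectly finite while its gradient is NaN:

import torch

x = torch.tensor(0.0, requires_grad=True)
loss = x ** x          # 0**0 evaluates to 1.0, so the loss itself is finite
loss.backward()
print(loss.item())     # 1.0
print(x.grad)          # tensor(nan): the backward of pow involves 0**(-1) and log(0)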

The message provided by anomaly detection points at the step in the computational graph where such an operation occurs, leading to the NaN gradients.

Limited floating-point precision could also be a reason, but it is less likely than the above.

Srishti is quite right. I would also note that the RuntimeError comes from ExpBackward, i.e. backprop through an exponential function. This implies that an exponential is used somewhere in the network.
The exponential blows up too fast (exponentially fast :)) → overflow, or it can return values that are too small → underflow. I would take a look at this part of the network.
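
For instance, in float32 the exponential already overflows or underflows for moderately large arguments (a quick check, independent of any particular model):

import torch

z = torch.tensor([100.0, -105.0, 0.0])
print(torch.exp(z))
# tensor([inf, 0., 1.]): exp(100) overflows float32 (max ~3.4e38),
# exp(-105) underflows to 0, and either case can turn gradients into inf/nan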

Maybe apply some cutoff to the output of the exponential function, just an idea.
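
One way to implement such a cutoff (just a sketch; here the argument of exp is clamped rather than its result, so that the backward pass also stays finite):

import torch

def safe_exp(z, max_arg=20.0):
    # Clamp the exponent so exp() stays well within float32 range
    # (exp(20) is roughly 4.85e8).
    return torch.exp(torch.clamp(z, max=max_arg))

If very negative arguments are also a problem, a lower bound can be clamped in the same way.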