How to inspect whether there are NaNs or Infs in gradients after AMP?

Hi guys,
I want to know how to inspect whether there are NaNs or Infs in the gradients after AMP.

optimizer.zero_grad()
with autocast():
    output = model(input)
    loss = loss_fn(output, target)
scaler.scale(loss).backward()
# How do I inspect whether there are Infs or NaNs in the gradients after AMP?
if contain_inf_nan():
    # Instead of silently skipping the step, raise an error if there are Infs or NaNs
    raise RuntimeError
scaler.step(optimizer)
scaler.update() 

Any answer or guidance would be appreciated!


You can inspect the .grad attribute of each parameter (this is also what scaler.step() does internally, so that the parameters are not updated if invalid gradients were found, e.g. due to a too large scaling factor), for example via:

if not all(torch.isfinite(p.grad).all() for p in model.parameters() if p.grad is not None):
    print("invalid gradients")

Thank you sincerely!


Quick follow-up in case it was missed: scaler.step(optimizer) already checks for invalid gradients. If any are found, the internal optimizer.step() call is skipped and the subsequent scaler.update() decreases the scaling factor to avoid overflows in the next training iteration.
If you skip these steps manually, you might get stuck with the same (too high) scaling factor and keep running into overflows and thus invalid gradients, so keep this in mind for your use case.
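
If you only want to know whether a given iteration was skipped (rather than raising an error yourself), one option is to compare the scale factor before and after scaler.update(), since the scaler lowers it when non-finite gradients were found. A minimal sketch, assuming the same scaler and optimizer variables as above:

scale_before = scaler.get_scale()
scaler.step(optimizer)   # internally skips optimizer.step() if gradients are non-finite
scaler.update()          # decreases the scale factor after a skipped step
if scaler.get_scale() < scale_before:
    print("optimizer step was skipped due to Inf/NaN gradients")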


Thanks for your guidance. I had already noticed this behavior when using apex.