How to go about finding the cause of a loss spike

I’m curious to hear strategies for approaching this problem, which I think is a common one:

How do you go about figuring out what part of the model or loss function is causing a rare loss spike?

PyTorch has detect_anomaly, but that slows down training massively. One could add torch.isfinite() checks to the forward pass, but that won’t detect things going wrong in the backward pass. One could also add hooks for the backward pass, but that’s a lot of effort.
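
For concreteness, here is a minimal sketch of what the backward-hook approach might look like (model is a placeholder nn.Module; make_nan_hook is just an illustrative name):

import torch

def make_nan_hook(name):
    # grad_output holds the gradients flowing back into this module during backward
    def hook(module, grad_input, grad_output):
        for g in grad_output:
            if g is not None and not torch.isfinite(g).all():
                print(f"Non-finite gradient entering {name} ({type(module).__name__})")
    return hook

# register the hook on every submodule so the first offender gets reported
for name, module in model.named_modules():
    module.register_full_backward_hook(make_nan_hook(name))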

What other approaches might be possible?


You can make use of gradient clipping.

https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html
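
For reference, a minimal sketch of where the call fits in a training step (model, optimizer, loss, and max_norm=1.0 are placeholders):

loss.backward()
# rescale all gradients so their combined norm is at most max_norm
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()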

If it’s an RL problem, you might also try clipping the objective function.

https://spinningup.openai.com/en/latest/algorithms/ppo.html
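
For context, the clipped surrogate objective described there is

L^{CLIP}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\right)\right]

where r_t(\theta) is the new-to-old policy probability ratio and \hat{A}_t the advantage estimate; the clip keeps any single update from moving the policy too far.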

This reply does not answer the question of how to find the cause.


As stated in a direct message, absent gradient clipping, gradients tend to get smaller the further back through the model you go in the backward pass. It’s a consequence of how the derivatives are calculated: repeated multiplication via the chain rule.

When they get small enough, certain optimizers such as RMSProp and Adam, which put a quantity computed from the gradients (a running average of their squares) in the denominator of the update, can make the step size jump as the gradients approach zero. So it is likely that the first layer of your model is what produces the issue.
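
Concretely, Adam’s step divides a running average of the gradients by the square root of a running average of their squares:

\theta_{t+1} = \theta_t - \eta\,\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}

If \hat{v}_t shrinks toward zero faster than \hat{m}_t, that ratio (and hence the step) can suddenly become large, which is one way a loss spike shows up.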

If you want to see what is occurring, you can print the average gradient for each layer:

import torch

# after loss.backward() (gradients persist until optimizer.zero_grad() clears them)
for name, param in model.named_parameters():
    if param.grad is not None:  # frozen or unused parameters have no gradient
        print(name, torch.mean(param.grad).item())

Also see here:

I was hoping to start a discussion. Thank you for making me realize this is not the right place.
