Hi, I am working on a PyTorch model that has several recurrent autoencoders operating on latent data generated by an outer model. I have run into substantial stability problems: most of the time the gradient explodes at some point, resulting in NaN weights and loss values (certain model configurations appear to be less prone to this than others, for reasons unknown).
To debug this, I’d like to understand how the gradient explodes, that is, where and how the large gradients are created. For example, it could be that a few individual operations generate very large gradients that accumulate to infinity, or that many moderate gradients add up to an explosion. It is also possible that I overlooked some 1/x issue, but I checked this multiple times and guarded everything with epsilons, so that seems unlikely.*
What is the best approach in PyTorch to debug this? I thought it must be possible to somehow extract from autograd the actual graph used to compute the derivatives for each parameter. Is this feasible? If not, how could I proceed?
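To make the question concrete, a minimal sketch of this kind of inspection on a toy model might look as follows. Everything here is an assumption for illustration: the `GRUCell` is a hypothetical stand-in for one recurrent autoencoder, the data is random, and the helper names `grad_norms` and `backward_nodes` are made up. It walks the autograd graph through `grad_fn.next_functions`, records the gradient norm at every time step with tensor hooks, and lists per-parameter gradient norms after `backward()`.

```python
import torch
import torch.nn as nn


def grad_norms(model):
    """Return {parameter name: L2 norm of its gradient} after backward()."""
    return {
        name: p.grad.detach().norm().item()
        for name, p in model.named_parameters()
        if p.grad is not None
    }


def backward_nodes(tensor):
    """List the type names of all backward nodes reachable from tensor.grad_fn.

    Iterative on purpose: a recursive walk could hit Python's recursion
    limit on graphs built from thousands of time steps.
    """
    names, seen, stack = [], set(), [tensor.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        names.append(type(fn).__name__)
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return names


torch.manual_seed(0)
cell = nn.GRUCell(4, 8)       # hypothetical stand-in for a recurrent autoencoder
x = torch.randn(20, 2, 4)     # toy data: (time, batch, features)
h = torch.zeros(2, 8)

step_grad_norms = {}          # time step -> norm of dLoss/dh_t
for t in range(x.shape[0]):
    h = cell(x[t], h)
    # The hook fires during backward with the gradient flowing into h_t.
    # It returns None, so the gradient itself is left unchanged.
    h.register_hook(
        lambda g, t=t: step_grad_norms.__setitem__(t, g.norm().item())
    )

loss = h.pow(2).mean()
nodes = backward_nodes(loss)  # inspect the graph before it is consumed
loss.backward()
param_norms = grad_norms(cell)

worst_step = max(step_grad_norms, key=step_grad_norms.get)
print(f"{len(nodes)} backward nodes, largest step gradient at t={worst_step}")
for name, norm in param_norms.items():
    print(f"{name}: {norm:.3e}")
```

On the real model, the per-step norms would show whether the gradient grows steadily along the sequence or jumps at a few steps, and the per-parameter norms would show which submodule receives the large values.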
Unfortunately, it is difficult to post a minimal working example: the problem is highly dependent on the data I use, which I cannot share, and it takes a very long time to reproduce the gradient explosion (in particular because the recurrent autoencoders are very slow to train).
edit: I use torch 2.10.0.
*: I process very long sequences (~8000 steps) and work with four recurrent autoencoders. However, I only see large gradients when I optimize them with respect to the overall output signal of the total model (my model = outer model + inner model; the inner model consists of these four recurrent autoencoders) using the SI-SDR. When I optimize the inner model with respect to the MSE between its immediate input and output signals, no explosion occurs. So the long sequence alone does not explain the exploding gradient. Of course I use gradient clipping, but to no avail.
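For reference, the setup described in this footnote can be sketched as below. This is not the actual implementation, only an illustration under stated assumptions: `si_sdr_loss` is a hypothetical epsilon-guarded negative SI-SDR following the usual definition (project the estimate onto the target, then take the log energy ratio of the projected part to the residual), the mean removal is a common convention rather than something confirmed here, and the `Linear` layer is a toy stand-in for the full model.

```python
import torch


def si_sdr_loss(estimate, target, eps=1e-8):
    """Negative SI-SDR in dB, averaged over the batch.

    Both inputs have shape (batch, time). eps guards the two divisions
    and the logarithm against zero energies.
    """
    # Remove the mean so that DC offsets do not affect the measure.
    estimate = estimate - estimate.mean(dim=-1, keepdim=True)
    target = target - target.mean(dim=-1, keepdim=True)
    # Scaled target: projection of the estimate onto the target.
    dot = (estimate * target).sum(dim=-1, keepdim=True)
    target_energy = target.pow(2).sum(dim=-1, keepdim=True)
    s_target = dot / (target_energy + eps) * target
    e_noise = estimate - s_target
    ratio = (s_target.pow(2).sum(dim=-1) + eps) / (e_noise.pow(2).sum(dim=-1) + eps)
    return -10.0 * torch.log10(ratio).mean()


torch.manual_seed(0)
model = torch.nn.Linear(16, 16)   # toy stand-in for outer + inner model
x = torch.randn(4, 16)
reference = torch.randn(4, 16)

loss = si_sdr_loss(model(x), reference)
loss.backward()

# clip_grad_norm_ returns the total gradient norm measured BEFORE clipping,
# so logging it each step shows when the explosion starts.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"loss={loss.item():.3f}  pre-clip grad norm={total_norm.item():.3e}")
```

One thing worth checking with the logged value: if the pre-clip norm is already `inf` or NaN, rescaling cannot restore finite gradients, which would be consistent with clipping not helping. Passing `error_if_nonfinite=True` to `clip_grad_norm_` makes that case raise an error at the step where it first occurs.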