I am working on an architecture in which I sporadically get exploding gradients, and I want to find out exactly which operation is causing them. I have already identified the parameters affected by these huge gradients and have code that detects when unusual gradients occur, but I am unsure how to proceed.
I think I know what causes it for some parameters, but for others I have no clue.
Ideally, I would recalculate the gradient while retaining the graph and interactively traverse the gradient calculation backwards to find out what is going on.
When I recalculate my gradients using .backward(retain_graph=True), the grad_fn of the parameters is still None, and I am not sure how to actually interact with the graph and find the reason for the exploding gradient.
EDIT: OK, I've found out that with create_graph=True the grad_fn gets populated, but I am unsure whether it is possible to interact with it in a meaningful way. I have not found a way to evaluate these nodes or get their context.
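One hedged way to read the graph: parameters are leaf tensors, so their own grad_fn is always None; the backward graph hangs off the output (e.g. the loss) instead. Starting from loss.grad_fn, each node lists the nodes it feeds into via next_functions, and AccumulateGrad nodes expose the leaf tensor they write to through .variable. The sketch below walks that graph depth-first on a toy model (the model and the walk_graph helper are hypothetical stand-ins). It shows the structure only, not the values flowing through it.

```python
import torch
import torch.nn as nn

def walk_graph(fn, depth=0, seen=None, out=None):
    """Depth-first walk over the backward graph; returns (depth, node description) pairs.

    Recursive for brevity; very deep graphs may need an explicit stack instead.
    """
    if seen is None:
        seen, out = set(), []
    if fn is None or fn in seen:
        return out
    seen.add(fn)
    desc = type(fn).__name__
    if hasattr(fn, "variable"):
        # AccumulateGrad nodes hold the leaf tensor (e.g. a parameter) whose .grad they fill
        desc += f" -> leaf {tuple(fn.variable.shape)}"
    out.append((depth, desc))
    # next_functions holds (node, input index) pairs; node is None for inputs without grad
    for next_fn, _ in fn.next_functions:
        walk_graph(next_fn, depth + 1, seen, out)
    return out

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
loss = model(torch.randn(2, 4)).pow(2).mean()
print(model[0].weight.grad_fn)  # None: parameters are leaf tensors
for depth, desc in walk_graph(loss.grad_fn):
    print("  " * depth + desc)
```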
What kind of code are you currently using to “identify when unusual gradients occur”?
You could register hooks on all parameters and print a debug message when high gradients are propagated.
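A minimal sketch of that suggestion, assuming a fixed gradient-norm threshold (the add_grad_alarms helper and the threshold values are made up for illustration). Tensor.register_hook calls the hook with a parameter's gradient as soon as it has been computed; returning None leaves the gradient unchanged.

```python
import torch
import torch.nn as nn

def add_grad_alarms(model, threshold=1e3):
    """Register a hook on every parameter that reports unusually large gradients."""
    alarms = []   # collected (parameter name, gradient norm) pairs
    handles = []  # keep these to remove the hooks later
    for name, param in model.named_parameters():
        def hook(grad, name=name):  # default arg binds the current name
            norm = grad.norm().item()
            if norm > threshold:
                print(f"[grad alarm] {name}: norm={norm:.3e}")
                alarms.append((name, norm))
            # returning None keeps the gradient as it is
        handles.append(param.register_hook(hook))
    return alarms, handles

model = nn.Linear(4, 1)
alarms, handles = add_grad_alarms(model, threshold=10.0)
x = torch.full((8, 4), 100.0)  # deliberately large input -> large weight gradient
model(x).sum().backward()
for h in handles:
    h.remove()
```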
Once the step is found, you could check all .grad attributes as well as the parameters to see which operation causes the high gradients.
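For the check on a flagged step, a small report that ranks parameters by gradient norm and puts the parameter norm next to it can be a starting point. grad_report is a hypothetical helper, and the statistics it prints are just one reasonable choice.

```python
import torch
import torch.nn as nn

def grad_report(model, top_k=5):
    """Rank parameters by gradient norm and show the matching parameter norm."""
    rows = []
    for name, p in model.named_parameters():
        if p.grad is None:
            continue  # parameter did not take part in this backward pass
        g = p.grad.detach()
        rows.append({
            "name": name,
            "grad_norm": g.norm().item(),
            "grad_max": g.abs().max().item(),
            "param_norm": p.detach().norm().item(),
            "finite": bool(torch.isfinite(g).all()),
        })
    rows.sort(key=lambda r: r["grad_norm"], reverse=True)
    for r in rows[:top_k]:
        print(f"{r['name']:<20} grad={r['grad_norm']:.3e} max={r['grad_max']:.3e} "
              f"param={r['param_norm']:.3e} finite={r['finite']}")
    return rows

model = nn.Linear(4, 1)
model(torch.full((8, 4), 100.0)).sum().backward()
rows = grad_report(model)
```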
Maybe I’ve misunderstood your question and you are already doing exactly this.
@ptrblck Yeah, I am doing something similar. I keep track of whether the gradient norm is unexpected based on previous gradients and some absolute conditions. I am also generating plots of the maximum and mean gradient norm per parameter. So I have already identified some parameters that sometimes (randomly) get unusually high gradients (very, very large, on the order of 1e+6).
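The history-based check described here could look roughly like this: an exponential moving average of each parameter's gradient norm, plus an absolute limit and a non-finite check. GradNormMonitor and all of its default thresholds are illustrative assumptions, not values from this thread. Flagged steps are deliberately kept out of the running average so that one spike does not raise the baseline.

```python
import math

class GradNormMonitor:
    """Flags a gradient norm that is large relative to its own history or in absolute terms.

    ratio, abs_limit, beta and warmup are illustrative defaults, not tuned values.
    """

    def __init__(self, ratio=10.0, abs_limit=1e4, beta=0.99, warmup=10):
        self.ratio, self.abs_limit, self.beta, self.warmup = ratio, abs_limit, beta, warmup
        self.ema = {}    # parameter name -> moving average of its gradient norm
        self.steps = {}  # parameter name -> number of normal steps folded into the average

    def check(self, name, norm):
        """Return a reason string if `norm` looks anomalous, else None."""
        reason = None
        if not math.isfinite(norm):
            reason = "non-finite"
        elif norm > self.abs_limit:
            reason = f"above absolute limit {self.abs_limit:.1e}"
        elif name in self.ema and self.steps[name] >= self.warmup and norm > self.ratio * self.ema[name]:
            reason = f"{norm / self.ema[name]:.1f}x the running average"
        if reason is None:  # only normal steps update the history
            prev = self.ema.get(name, norm)
            self.ema[name] = self.beta * prev + (1 - self.beta) * norm
            self.steps[name] = self.steps.get(name, 0) + 1
        return reason

monitor = GradNormMonitor(ratio=10.0, abs_limit=1e4, warmup=5)
history = [1.0, 1.1, 0.9, 1.0, 1.05, 1.0, 50.0, 2e5, float("nan")]
flags = [monitor.check("conv1.weight", n) for n in history]
print(flags)
```

In a training loop this would be called after backward(), e.g. once per parameter with p.grad.norm().item().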
I suspect that certain parametrizations of the operations I use lead to exploding gradients further down the network, since it occurs without a pattern and is relatively unaffected by other hyperparameters. The immediate operations are fairly standard (for example, some of the affected modules are standard 2D convolutions).
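If the explosion originates further down the network, it can help to see where along the backward pass the gradient gets amplified. A sketch using Module.register_full_backward_hook on every leaf module: each record stores the norm of the gradient arriving at the module's output, the norm leaving through its input, and their ratio. Scale is a hypothetical stand-in for an operation that amplifies gradients. Caveats (assumptions worth checking on your setup): the first module reports 0 for its input gradient when the data does not require grad, and full backward hooks may clash with in-place operations such as ReLU(inplace=True) applied to a hooked module's output.

```python
import torch
import torch.nn as nn

def trace_backward(model):
    """Record, per leaf module, how much it amplifies the gradient flowing through it."""
    records, handles = [], []

    def norm_of(grads):
        # grad_input / grad_output are tuples that may contain None entries
        total = sum(g.detach().pow(2).sum() for g in grads if g is not None)
        return float(total) ** 0.5 if torch.is_tensor(total) else 0.0

    for name, module in model.named_modules():
        if len(list(module.children())) > 0:
            continue  # only hook leaf modules (Conv2d, Linear, activations, ...)
        def hook(mod, grad_input, grad_output, name=name):
            g_out, g_in = norm_of(grad_output), norm_of(grad_input)
            records.append((name, g_out, g_in, g_in / max(g_out, 1e-12)))
        handles.append(module.register_full_backward_hook(hook))
    return records, handles

class Scale(nn.Module):
    """Hypothetical amplifying operation used only for this demo."""
    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def forward(self, x):
        return x * self.factor

model = nn.Sequential(nn.Linear(4, 8), Scale(1e6), nn.Linear(8, 1))
records, handles = trace_backward(model)
model(torch.randn(2, 4)).sum().backward()
for name, g_out, g_in, ratio in records:  # records arrive in backward order
    print(f"{name}: |grad out|={g_out:.3e} |grad in|={g_in:.3e} ratio={ratio:.3e}")
for h in handles:
    h.remove()
```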
But I am not sure: some of the operations are very dynamic in nature, and some are predicted by another network. It is quite delicate, and I have already considered exploding gradients in certain scenarios, but it seems I have not eliminated every scenario.
So I have the code to start a debugging session, the code to identify the affected parameters, and the code to move my model and the data onto the CPU, where much more RAM is available and I could potentially trace the operations. But I am unsure how to proceed from here to identify the culprits that make my gradients explode. They must originate somewhere.
Sometimes it happens after a few minutes; sometimes I have to wait a few hours.
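Since the event can take hours to show up, one option is to dump everything needed to replay the offending step and then debug it offline on the CPU. A sketch, assuming a CPU model: capture the RNG state at the start of every step (cheap), and on a flagged step save the pre-update weights, the optimizer state, the batch and that RNG state before optimizer.step() runs. snapshot is a hypothetical helper. On a GPU you would also need the CUDA RNG state, and nondeterministic kernels may keep the replay from being bit-exact.

```python
import io
import torch
import torch.nn as nn

def snapshot(model, optimizer, batch, rng_state, f):
    """Save what is needed to replay one training step offline.

    Call it after backward() but before optimizer.step(), so the saved weights are
    the ones that produced the flagged gradients. `f` may be a path or a file object.
    """
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "batch": batch,
        "cpu_rng": rng_state,  # CPU RNG state captured at the start of the step
    }, f)

# --- during training (toy example) ---
model = nn.Sequential(nn.Linear(4, 8), nn.Dropout(0.5), nn.Linear(8, 1))
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(16, 4)
rng = torch.get_rng_state()        # cheap enough to do every step
loss = model(x).pow(2).mean()
loss.backward()
buf = io.BytesIO()                 # stands in for a file on disk
snapshot(model, opt, x, rng, buf)  # pretend the monitor flagged this step
ref_grad = model[0].weight.grad.clone()
opt.step()

# --- later, offline on the CPU: rebuild the model and replay the step ---
buf.seek(0)
state = torch.load(buf)
replay = nn.Sequential(nn.Linear(4, 8), nn.Dropout(0.5), nn.Linear(8, 1))
replay.load_state_dict(state["model"])
torch.set_rng_state(state["cpu_rng"])  # reproduces the dropout mask
replay(state["batch"]).pow(2).mean().backward()
```

The replayed model can then be inspected with the same hooks as the live one, without waiting hours for the event to recur.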
Ideally, I would have an interactive conversation with the autograd framework, in which I could ask it questions about the results and the responsible computations and parameters to isolate the origin. But I am fine with a cruder approach as long as I don't have to guess. There are a lot of moving parts.
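A crude version of that conversation: capture the outputs of all leaf modules with forward hooks, then ask torch.autograd.grad for the gradient of the loss with respect to any of them. With retain_graph=True the same graph can be queried repeatedly without touching .grad. capture_activations is a hypothetical helper and assumes every leaf module returns a single tensor.

```python
import torch
import torch.nn as nn

def capture_activations(model):
    """Keep a reference to every leaf module's output so its gradient can be queried later."""
    acts, handles = {}, []
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:
            def fwd_hook(mod, inputs, output, name=name):
                acts[name] = output  # assumes a single tensor output
            handles.append(module.register_forward_hook(fwd_hook))
    return acts, handles

model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))
acts, handles = capture_activations(model)
loss = model(torch.randn(3, 4)).pow(2).mean()

# Question 1: how large is d(loss)/d(output) at each layer?
names = list(acts)
grads = torch.autograd.grad(loss, [acts[n] for n in names], retain_graph=True)
for n, g in zip(names, grads):
    print(f"{n}: |dL/d out| = {g.norm().item():.3e}")

# Question 2: the same graph can still be queried, e.g. for a single parameter
(gw,) = torch.autograd.grad(loss, [model[0].weight], retain_graph=True)
for h in handles:
    h.remove()
```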
@Prerna_Dhareshwar Gradient clipping does not solve my problem, since what I am interested in is the source of the instability; clipping is more of a duct-tape approach. The work is research-oriented, but I don't see an immediate reason why the model must be unstable. I think it is just a special condition I overlooked, but I am not sure how to interact with the autograd framework in this scenario.