The scale must be greater than 1.0; otherwise it makes gradient underflow worse instead of preventing it. However, due to instability during training, the scale may drop below 1.0, causing a death spiral and ultimately crashing the entire training process.
You are right that instabilities during training could cause a “death spiral”.
However, I would rather recommend fixing the actual training instability instead of limiting, e.g., the scaling factor, as a limit won’t save you from the underlying failure.
Based on my testing with the model, a high total gradient norm may cause the loss scale to decrease progressively, potentially leading to a death spiral. As a solution, I have manually set the minimum scale to 1.0. Given enough time, the training process tends to stabilize, and the scale usually maintains a normal value of around 32.0. It would be beneficial to allow these developing models an opportunity to adjust their initial parameters.
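For readers who want to try the floor described above, here is a minimal sketch of one way it could be done. It assumes `scaler` behaves like `torch.cuda.amp.GradScaler`, i.e. it exposes `update()`, `get_scale()`, and `update(new_scale=...)`. The helper name `update_with_floor` and the `min_scale` parameter are hypothetical and not part of PyTorch.

```python
def update_with_floor(scaler, min_scale=1.0):
    """Run the normal scale update, then enforce a lower bound on the scale.

    `scaler` is assumed to follow the GradScaler interface:
      - update()            adjusts the scale based on the last step
      - get_scale()         returns the current scale as a Python float
      - update(new_scale=x) overrides the scale with x
    Returns the scale that is in effect afterwards.
    """
    scaler.update()              # regular backoff / growth logic
    scale = scaler.get_scale()   # on a real GradScaler this forces a CPU-GPU sync
    if scale < min_scale:
        # Hypothetical floor: push the scale back up to the minimum.
        scaler.update(new_scale=min_scale)
        return min_scale
    return scale
```

In a training loop this would replace the plain `scaler.update()` call after `scaler.step(optimizer)`. It only hides the symptom: steps that overflow are still skipped, so it does not address the cause of the large gradients.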
Additionally, I would suggest that the scaler’s update function skip batches with NaN losses. My model occasionally produces NaN values due to its mathematical operations, although this happens only about once every 2000 batches. Nonetheless, this could contribute to a death spiral as well.
Again, I would recommend trying to fix the actual NaN outputs, or skipping the iteration manually whenever these NaN outputs are created. It sounds as if you are trying to change the gradient scaling utility to fix an already unstable training run.
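Skipping the iteration manually could look roughly like the sketch below. It is only an illustration: `run_epoch`, `compute_loss`, and `apply_step` are hypothetical names. `compute_loss` stands for the forward pass, and `apply_step` stands for the usual `scaler.scale(loss).backward()`, `scaler.step(optimizer)`, `scaler.update()` sequence.

```python
import math


def run_epoch(batches, compute_loss, apply_step):
    """Iterate over batches, skipping any batch whose loss is NaN or inf.

    compute_loss(batch) -> loss   (a float, or a 0-dim tensor convertible via float())
    apply_step(loss)              (backward pass + optimizer/scaler step)
    Returns the number of skipped batches.
    """
    skipped = 0
    for batch in batches:
        loss = compute_loss(batch)
        # float() on a CUDA tensor synchronizes, which is the price of the check.
        if not math.isfinite(float(loss)):
            skipped += 1
            continue  # no backward, no scaler.update(): the scale is left untouched
        apply_step(loss)
    return skipped
```

Because the skipped batch never reaches the scaler, a NaN loss cannot trigger a backoff of the scale. Logging the returned count makes it easy to check that skips stay rare.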
I would be interested to learn more about the high gradient magnitudes that are not caused by a huge loss or invalid outputs, in particular which operations seem to blow up the gradients and cause the overflow during scaling.
In my model, I employed torch.linalg.solve and torch.det, which can result in NaN values when the positive-definite condition is violated due to underflow.
The total gradient norm of my model is in the hundreds, and occasionally, it reaches the thousands. However, the loss and outputs seem valid, perhaps because I used torch.nn.utils.clip_grad_norm_, and my model functions well. I am curious to know if it is typical to have such a large total gradient norm.
I don’t fully understand this explanation. Is the gradient norm in the thousands before or after unscaling?
If afterwards, how would a low scaling factor cause an underflow? I also assume you are clipping the gradients after unscaling as described in this amp example?
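For reference, the ordering in question (unscale first, then clip, then step) can be sketched as follows. The function name `scaled_step_with_clipping` is hypothetical. `scaler` is assumed to follow the `GradScaler` interface, and `clip_fn` is assumed to have the signature of `torch.nn.utils.clip_grad_norm_(parameters, max_norm)`.

```python
def scaled_step_with_clipping(loss, params, optimizer, scaler, clip_fn, max_norm=1.0):
    """One AMP optimizer step with gradient clipping applied to unscaled gradients.

    Returns whatever clip_fn returns (for clip_grad_norm_, the total gradient
    norm measured before clipping).
    """
    optimizer.zero_grad(set_to_none=True)
    scaler.scale(loss).backward()   # gradients are multiplied by the loss scale
    scaler.unscale_(optimizer)      # divide gradients by the scale, in place
    total_norm = clip_fn(params, max_norm)  # clip the true (unscaled) gradients
    scaler.step(optimizer)          # skipped internally if infs/NaNs were found
    scaler.update()                 # adjust the scale for the next iteration
    return total_norm
```

If the clipping ran before `unscale_`, `max_norm` would be compared against gradients that are still multiplied by the loss scale, so the effective threshold would change every time the scale changes.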
I don’t know as I haven’t seen it before, but I’m also not familiar with your use case as you are already expecting to see NaN losses sometimes, which is also uncommon.
The gradient norm is in the thousands after unscaling.
I presume that in my model, the gradient norm oscillates, frequently soaring high and then plummeting low. The loss scale tends to decrease with a high gradient norm. However, when the gradient norm falls to a low level, the loss scale struggles to promptly recover to a higher value.
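The asymmetry described here can be illustrated with a small simulation. This is a simplified model of the scale update rule, not the library code, and `simulate_scale` is a hypothetical name. The defaults mirror the documented `GradScaler` defaults (`init_scale=65536.0`, `growth_factor=2.0`, `backoff_factor=0.5`, `growth_interval=2000`).

```python
def simulate_scale(overflow_steps, n_steps, init_scale=65536.0,
                   growth_factor=2.0, backoff_factor=0.5, growth_interval=2000):
    """Simplified model of dynamic loss scaling.

    On a step with inf/NaN gradients the scale is multiplied by backoff_factor
    and the growth counter resets. After growth_interval consecutive clean
    steps the scale is multiplied by growth_factor and the counter resets.
    Returns the scale after n_steps.
    """
    overflow_steps = set(overflow_steps)
    scale = init_scale
    clean_streak = 0
    for step in range(n_steps):
        if step in overflow_steps:
            scale *= backoff_factor
            clean_streak = 0
        else:
            clean_streak += 1
            if clean_streak == growth_interval:
                scale *= growth_factor
                clean_streak = 0
    return scale


# Ten overflowing steps in a row cut the scale by a factor of 1024 ...
after_burst = simulate_scale(range(10), n_steps=10)
# ... while 2000 clean steps afterwards win back only a single factor of 2.
after_recovery = simulate_scale(range(10), n_steps=10 + 2000)
```

With these defaults the scale drops one step at a time but needs 2000 consecutive clean steps per doubling, so a short burst of overflows takes tens of thousands of steps to undo. Lowering `growth_interval` when constructing the scaler is one knob that speeds up recovery, at the cost of more frequent overflow probes.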