I’m listing here a few things that I found mentioned in connection with this issue. For context, I was also training an LSTM-based model with AMP + DDP. The following, most of which are brought up in this issue, helped to stabilize my model:
- Fairly aggressive clipping of the gradients
- Replacing `softmax` operations with `exp(log_softmax)` for numerical stability
- Replacing `ReLU` activations with `Mish` to prevent dead neurons
- Larger epsilon parameter (`eps=1e-4`) for the BatchNorm layers
- Larger epsilon parameter (`eps=1e-4`) for the Adam optimizer
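For reference, the tweaks above might look roughly like this in PyTorch; the model, shapes, and clipping threshold are illustrative assumptions, not taken from the original issue:

```python
import torch
import torch.nn as nn

# Illustrative model; the original was "an LSTM-based model", details unknown.
model = nn.LSTM(input_size=32, hidden_size=64)

# Larger epsilon for the Adam optimizer and BatchNorm layers:
optimizer = torch.optim.Adam(model.parameters(), eps=1e-4)
bn = nn.BatchNorm1d(64, eps=1e-4)

# Mish instead of ReLU to prevent dead neurons:
act = nn.Mish()

# exp(log_softmax) instead of a direct softmax, for numerical stability:
logits = torch.randn(8, 10)
probs = torch.exp(torch.log_softmax(logits, dim=-1))

# Fairly aggressive gradient clipping (call between backward() and step();
# the max_norm value here is an assumption):
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
```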
The instability, however, persisted, and the problem was ultimately solved by changing the model architecture. More specifically, the running variance of one of the BN layers overflowed: the fix was to clamp the maximum value of the input tensor before forwarding it to the BN layer, e.g.
```python
...
x = self.relu(x)
x = torch.clamp(x, max=10.)
x = self.bn(x)
...
```
Since the clamping was done right after the ReLU (later Mish) activation, it essentially resulted in a clipped ReLU.
It turned out that the overflow was present even with AMP disabled; it just never grew large enough in full precision to produce NaNs/Infs.
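If you suspect a similar overflow, one way to locate the offending layer is to scan the BatchNorm running statistics for values that are non-finite or exceed the fp16 maximum (65504). This helper is my own sketch, not something from the issue:

```python
import torch
import torch.nn as nn

def find_overflowing_bn(model, threshold=torch.finfo(torch.float16).max):
    """Return names of BatchNorm layers whose running_var is non-finite
    or exceeds `threshold` (default: the fp16 maximum, 65504)."""
    bad = []
    for name, module in model.named_modules():
        if isinstance(module, nn.modules.batchnorm._BatchNorm):
            rv = module.running_var
            if not torch.isfinite(rv).all() or rv.max() > threshold:
                bad.append(name)
    return bad
```

Calling this every few training steps points directly at the layer whose statistics are blowing up, before the NaNs/Infs reach the loss.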