Nan Loss with torch.cuda.amp and CrossEntropyLoss

For context, I was also training an LSTM-based model with AMP + DDP. Several of the suggestions brought up earlier in this thread helped to stabilize my model.

The instability, however, persisted, and the problem was ultimately solved by changing the model architecture. More specifically, the running variance of one of the BatchNorm layers overflowed; the fix was to clamp the maximum value of the input tensor before passing it to the BN layer, e.g.

...
x = self.relu(x)
x = torch.clamp(x, max=10.)  # bound the activations so the BN running variance stays within the fp16 range
x = self.bn(x)
...

Since the clamping was applied right after the ReLU (later Mish) activation, it essentially amounted to a clipped ReLU.
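
As a rough illustration, the pattern looks like this when written as a self-contained module. This is a minimal sketch, not the exact architecture from my model; the layer types, sizes, and the clip value of 10 are assumptions for the example:

import torch
import torch.nn as nn

class ClippedReLUBN(nn.Module):
    # Conv -> ReLU -> clamp -> BatchNorm: the clamp bounds the activations so the
    # BN running variance cannot blow up past the fp16 range under AMP.
    def __init__(self, in_channels, out_channels, clip_val=10.0):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm1d(out_channels)
        self.clip_val = clip_val

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        # clipped ReLU: cap the activations before they reach the BN layer
        x = torch.clamp(x, max=self.clip_val)
        x = self.bn(x)
        return x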

It turned out that the overflow was present even with AMP disabled; in full fp32 precision, however, it never caused NaNs/Infs to appear.
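
If you want to check whether the same thing is happening in your model, one way (a sketch I'm adding here, not something from the original thread; the threshold is an assumption) is to inspect the running_var buffers of the BatchNorm layers during or after training. Values approaching the fp16 maximum (~65504) or non-finite values point to this kind of overflow:

import torch
import torch.nn as nn

def report_bn_running_var(model, threshold=1e4):
    # Print any BatchNorm layer whose running variance is non-finite or
    # suspiciously close to the fp16 maximum (~65504).
    for name, module in model.named_modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            rv = module.running_var
            if not torch.isfinite(rv).all() or rv.max().item() > threshold:
                print(f"{name}: running_var max = {rv.max().item():.3e}")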
