Same issue, but I think I've resolved it. When we use a loss function that contains log(), like Focal Loss or Cross Entropy, some elements of the input tensor may be very small numbers. With dtype = float32 they are still bigger than zero, but amp changes the dtype to float16, where they underflow. If we check these elements, we will find they are [0.]. As the input of log() they give -inf, and the loss ends up as NaN. There are two ways to solve the problem:
- add a small number inside log(), like 1e-3. The price is a loss of precision
- make the dtype of the input of log() float32
e.g.:
yhat = torch.sigmoid(input).type(torch.float32)
loss = -y * ((1 - yhat) ** self.gamma) * torch.log(yhat + 1e-20) - (1 - y) * (yhat ** self.gamma) * torch.log(1 - yhat + 1e-20)
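
To make the underflow concrete, here is a minimal sketch that should run on CPU. The value 1e-8 and the names `p32`, `p16`, `bad_term` and `good_term` are only chosen for illustration; they are not taken from any real model. It assumes the problem is exactly the float16 underflow described above.

```python
import torch

# A probability that is tiny but still nonzero in float32 ...
p32 = torch.tensor([1e-8], dtype=torch.float32)

# ... becomes exactly 0 after a cast to float16, because the smallest
# positive float16 value is roughly 6e-8.
p16 = p32.to(torch.float16)

# Target 0, so this term of the loss "should" contribute nothing.
y = torch.tensor([0.0])

# Without protection: log(0) = -inf, and 0 * -inf = NaN.
bad_term = y * torch.log(p16.to(torch.float32))

# Cast back to float32 and add a tiny epsilon before log():
# 1e-20 is representable in float32, so log() stays finite.
good_term = y * torch.log(p16.to(torch.float32) + 1e-20)

print(p32, p16, bad_term, good_term)
```

Note that the 1e-20 only helps because the addition happens in float32. In float16 it would underflow to zero as well, which is why a float16-only fix needs a much larger number such as 1e-3.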