Nan Loss with torch.cuda.amp and CrossEntropyLoss

Same issue here, but I think I've resolved it. Loss functions such as Focal Loss or Cross Entropy contain a log(), and some elements of its input can be very small numbers. In float32 they are still greater than zero, but amp runs many ops in float16, which cannot represent values that small. If we inspect those elements, we find they have become [0.]. As the input of log() they produce -inf, and terms such as 0 * -inf in the loss then give NaN. There are two ways to solve the problem:

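  1. Add a small number inside log(), like 1e-3. The price is some loss of precision. The constant must be representable in the dtype being used: 1e-3 survives in float16, but 1e-20 rounds to 0.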
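  2. Make the dtype of the input of log() float32, so that a tiny epsilon such as 1e-20 is representable. Casting the logits before the sigmoid also keeps the sigmoid output itself from underflowing to 0 in float16:
    e.g.: yhat = torch.sigmoid(input.float())
    loss = -y * ((1 - yhat) ** self.gamma) * torch.log(yhat + 1e-20) - (1 - y) * (yhat ** self.gamma) * torch.log(1 - yhat + 1e-20)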
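A minimal sketch of the underflow described above, using only the Python standard library instead of PyTorch. The struct module's "e" format packs a float into IEEE 754 half precision, the same format a float16 tensor stores, so round-tripping through it shows what survives the cast. The helper names to_float16 and sigmoid are hypothetical, chosen for illustration.

```python
import math
import struct


def to_float16(x: float) -> float:
    """Round a Python float to IEEE 754 half precision and back (hypothetical helper)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def sigmoid(x: float) -> float:
    """Plain logistic sigmoid in double precision."""
    return 1.0 / (1.0 + math.exp(-x))


# A confident prediction: sigmoid(-20) is about 2e-9, small but non-zero.
p = sigmoid(-20.0)
print(p)

# In float16 it rounds to 0.0: it is far below float16's smallest
# positive value, 2**-24 (about 6e-8). log() of this is log(0).
print(to_float16(p))

# An epsilon only helps if it survives the cast to the working dtype.
print(to_float16(1e-20))  # rounds to 0.0, so it cannot rescue log() in float16
print(to_float16(1e-3))   # stays close to 0.001, representable in float16
```

This is why the 1e-20 epsilon in item 2 only works once the tensor is in float32, while item 1 needs a larger constant such as 1e-3.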
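For comparison, here is a sketch of the binary focal loss from item 2 that avoids the epsilon entirely, assuming raw logits and 0/1 targets as input. It swaps in a different technique from the post: log(sigmoid(x)) is computed as F.logsigmoid(x) and log(1 - sigmoid(x)) as F.logsigmoid(-x), after upcasting to float32. The function name focal_loss_with_logits and the mean reduction are assumptions for illustration, not a PyTorch API.

```python
import torch
import torch.nn.functional as F


def focal_loss_with_logits(logits: torch.Tensor, target: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
    """Binary focal loss from raw logits, computed in float32 (hypothetical helper)."""
    # Upcast first so float16 logits from an autocast region never
    # yield a sigmoid output that has already underflowed to 0.
    logits = logits.float()
    target = target.float()
    p = torch.sigmoid(logits)
    # Numerically stable logs: no epsilon, no log(0).
    log_p = F.logsigmoid(logits)       # log(p)
    log_1mp = F.logsigmoid(-logits)    # log(1 - p)
    loss = -target * (1 - p) ** gamma * log_p - (1 - target) * p ** gamma * log_1mp
    return loss.mean()


# Usage: float16 logits as they might come out of an autocast forward pass.
logits = torch.tensor([-30.0, -5.0, 0.0, 5.0, 30.0], dtype=torch.float16)
target = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0])
print(focal_loss_with_logits(logits, target))
```

With gamma=0 this reduces to ordinary binary cross-entropy, so it should match F.binary_cross_entropy_with_logits. According to the PyTorch AMP docs, that function and F.cross_entropy are among the ops autocast already runs in float32, so the built-in losses are generally safe; hand-written losses like the one above are where the float16 underflow bites. Another option is to compute the loss inside `with torch.autocast(device_type="cuda", enabled=False):` after casting its inputs to float32.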