Loss and Dice metric become NaN after epoch 125

epoch: 125/200, subject: 1/2, batch: 26/32, avg-batch-loss: 0.2313, avg-batch-dice: 0.6728
epoch: 125/200, subject: 1/2, batch: 27/32, avg-batch-loss: 0.2289, avg-batch-dice: 0.6704
epoch: 125/200, subject: 1/2, batch: 28/32, avg-batch-loss: 0.2233, avg-batch-dice: 0.6989
epoch: 125/200, subject: 1/2, batch: 29/32, avg-batch-loss: 0.2275, avg-batch-dice: 0.6986
epoch: 125/200, subject: 1/2, batch: 30/32, avg-batch-loss: 0.2232, avg-batch-dice: 0.7028
epoch: 125/200, subject: 1/2, batch: 31/32, avg-batch-loss: 0.2241, avg-batch-dice: 0.7013
epoch: 125/200, subject: 1/2, batch: 32/32, avg-batch-loss: 0.2176, avg-batch-dice: 0.7159
Criteria at the end of epoch 125 subject 1 is 0.7159
Criteria increased from 0.6941 to 0.7159, saving model ...
epoch: 125/200, subject: 2/2, batch: 1/32, avg-batch-loss: nan, avg-batch-dice: nan
epoch: 125/200, subject: 2/2, batch: 2/32, avg-batch-loss: nan, avg-batch-dice: nan
epoch: 125/200, subject: 2/2, batch: 3/32, avg-batch-loss: nan, avg-batch-dice: nan

I am using two combined losses:

import torch

def focal_dice_loss(y_pred, y_true, delta=0.7, gamma_fd=0.75, epsilon=1e-6):
    axis = identify_axis(y_pred.shape) # spatial axes, [2,3,4] for 5D input; identify_axis is a helper not shown here
    ones = torch.ones_like(y_pred)
    p_c = y_pred      # proba that voxels are class i
    p_n = ones-y_pred
    g_t = y_true #.type(torch.FloatTensor) #cuda.FloatTensor)
    g_n = ones-g_t
    tp = torch.sum(torch.sum(p_c*g_t, axis), 0)
    fp = torch.sum(torch.sum(p_c*g_n, axis), 0)
    fn = torch.sum(torch.sum(p_n*g_t, axis), 0)
    tversky_dice = (tp+epsilon)/(tp + delta*fn + (1-delta)*fp + epsilon) #torch.Size([9])
    focal_dice_loss_fg = torch.pow((1-tversky_dice), gamma_fd)[1:] # removing 0 --> background
    dice_loss = torch.sum(focal_dice_loss_fg)
    focal_dice_per_class = torch.mean(focal_dice_loss_fg)
    return dice_loss, focal_dice_loss_fg, focal_dice_per_class
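To see how this loss can reach an edge case, here is a minimal sketch. It assumes 5D (N, C, D, H, W) inputs, so that identify_axis would return [2, 3, 4]. tversky_per_class is a hypothetical stripped-down helper that reproduces only the Tversky term of focal_dice_loss. When a class is absent from both the prediction and the ground truth, tp = fp = fn = 0, so the epsilon terms make the index exactly epsilon / epsilon = 1. The base passed to torch.pow then becomes exactly 0.

```python
import torch

def tversky_per_class(y_pred, y_true, delta=0.7, epsilon=1e-6):
    # same per-class Tversky index as focal_dice_loss, with the spatial
    # axes hard-coded to (2, 3, 4) for (N, C, D, H, W) inputs (assumption)
    axis = (2, 3, 4)
    tp = (y_pred * y_true).sum(axis).sum(0)
    fp = (y_pred * (1 - y_true)).sum(axis).sum(0)
    fn = ((1 - y_pred) * y_true).sum(axis).sum(0)
    return (tp + epsilon) / (tp + delta * fn + (1 - delta) * fp + epsilon)

# toy data: 1 sample, 2 classes, 2x2x2 volume; class 1 appears in neither
# the prediction nor the ground truth
y_true = torch.zeros(1, 2, 2, 2, 2)
y_true[:, 0] = 1.0
y_pred = y_true.clone()

t = tversky_per_class(y_pred, y_true)
print(t)          # class 1: tp = fp = fn = 0, so the index is epsilon / epsilon
print(1 - t[1])   # the base that focal_dice_loss would pass to torch.pow
```

The epsilon terms prevent a 0/0 division here, but they also make the index hit 1.0 exactly rather than land near it.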

def focal_loss(y_pred, y_true, clweight):
    # clamp probabilities away from 0 and 1 so that log() stays finite
    y_pred = torch.clamp(y_pred, 1e-6, 1 - 1e-6)
    cross_entropy = -y_true * torch.log(y_pred)
    # focal modulation (1 - p)^2; shape [2, 9, 48, 50, 64] = (N, C, D, H, W)
    weighted = cross_entropy * torch.pow(1 - y_pred, 2)
    # average over the spatial dims, then over the batch --> [9] (one value per class)
    floss = weighted.mean(dim=(2, 3, 4)).mean(dim=0) * clweight
    return torch.sum(floss)

I also trained with only the focal_loss above, and that doesn't give me NaN values. When I call (focal_loss + focal_dice_loss).backward(), I get NaN values after epoch 125.
Any help with troubleshooting?

I checked the data I am working with; it is normalized to the range [0, 1].

I guess a division or another operation might create invalid outputs, so you could add debug print statements to the code and check where the first NaN or Inf value is created.
Note that your model might also blow up in general, so that its parameters become invalid, or your custom loss function might be what creates these invalid values.
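One way to automate the debug prints, sketched here as an assumption rather than the only approach, is to use forward hooks instead of editing each forward method. add_nan_hooks is a hypothetical helper name, and the toy model with a deliberately corrupted weight only stands in for the real segmentation network. Each hook records the module name when its output contains NaN or Inf. The first entry is therefore the earliest module in the forward pass that produced an invalid value.

```python
import torch
import torch.nn as nn

def add_nan_hooks(model):
    """Attach forward hooks that record every module whose output has NaN/Inf."""
    found = []  # module names, in the order their forward passes finished

    def make_hook(name):
        def hook(module, inputs, output):
            outputs = output if isinstance(output, (tuple, list)) else (output,)
            for out in outputs:
                if torch.is_tensor(out) and not torch.isfinite(out).all():
                    found.append(name or "<root>")
                    break
        return hook

    handles = [m.register_forward_hook(make_hook(name))
               for name, m in model.named_modules()]
    return found, handles

# hypothetical toy model with a corrupted weight in the last layer
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
with torch.no_grad():
    model[2].weight[0, 0] = float("nan")

found, handles = add_nan_hooks(model)
model(torch.randn(3, 4))
print(found[0])  # "2": the first module that produced a non-finite output
for h in handles:
    h.remove()   # detach the hooks once debugging is done
```

Child hooks fire before the parent's hook, because the parent's forward only returns after its children have run. In a training loop you could check `found` after each batch and stop at the first non-empty result.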

Hi @ptrblck, someone suggested that replacing the Adam optimizer might help, so I tried:

optimizer = torch.optim.RMSprop(model.parameters(), lr=2e-3)

With the optimizer above, NaN appears after only 2 epochs!

Criteria at the end of epoch 1 subject 2 is 0.3492
criteria increased from 0.1109 to 0.3492, saving model ...
Nan found
Criteria at the end of epoch 2 subject 1 is 0.5875

Is it possible to find out which value becomes NaN first?

Also, to avoid division by zero, I included an epsilon value in all numerators and denominators, as you can see in the custom loss functions.

Yes, that was the suggestion in my previous post. You could add print statements in the forward method and check which activation gets these invalid values first to further isolate it.
Also, if the invalid values are created in the backward pass, you could use torch.autograd.set_detect_anomaly(True) and rerun the code.
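Here is a minimal sketch of anomaly detection on a toy tensor. It uses the context-manager form torch.autograd.detect_anomaly(), so the mode is switched off again afterwards. The two-element tensor is an assumption meant to mimic focal_dice_loss: element 0 plays the sliced-off background class, its base 1 - v[0] is exactly 0, and the [1:] slice gives it an upstream gradient of 0. find_nan_in_backward is a hypothetical helper name.

```python
import torch

def find_nan_in_backward():
    """Run a pow() backward at base 0 under anomaly mode; return the error text, if any."""
    # element 0 stands in for the background class removed by [1:]
    v = torch.tensor([1.0, 0.5], requires_grad=True)
    with torch.autograd.detect_anomaly():
        loss = torch.pow(1 - v, 0.75)[1:].sum()
        try:
            loss.backward()
        except RuntimeError as err:
            # anomaly mode raises as soon as a backward node returns NaN
            return str(err)
    return None

msg = find_nan_in_backward()
print(msg)
```

Anomaly mode also prints the forward-call traceback of the failing node as a warning. It slows training noticeably, so it is best enabled only while hunting for the first NaN.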

Hi @ptrblck
I added torch.autograd.set_detect_anomaly(True) to the script and also printed focal_dice_loss_fg from the custom loss function. It reports this forward call as the origin of the NaN:
fdloss, _, _ = focal_dice_loss(out, seg)

epoch: 1/200, subject: 1/2, batch: 31/32, avg-batch-loss: 7.7020, avg-batch-dice: 0.1117
tensor([1.0000, 1.0000, 1.0000, 1.0000, 0.9969, 0.9989, 0.9998, 1.0000],
       device='cuda:0', grad_fn=<SliceBackward>)
epoch: 1/200, subject: 1/2, batch: 32/32, avg-batch-loss: 7.7111, avg-batch-dice: 0.1084
Criteria at the end of epoch 1 subject 1 is 0.1084
criteria increased from 0.0000 to 0.1084, saving model ...
[W python_anomaly_mode.cpp:104] Warning: Error detected in PowBackward0. Traceback of forward call that caused the error:
  File "/home/banikr/PycharmProjects/HeadSeg/dummy.py", line 121, in <module>
    fdloss, _, _ = focal_dice_loss(out, seg)
  File "/home/banikr/PycharmProjects/HeadSeg/Utilities.py", line 391, in focal_dice_loss
    focal_dice_loss_fg = torch.pow((1-tversky_dice), gamma_fd)[1:] # removing 0 --> background
 (function _print_stack)
tensor([1.0000, 1.0000, 1.0000, 1.0000, 0.9292, 1.0000, 0.0000, 0.0000],
       device='cuda:0', grad_fn=<SliceBackward>)
Traceback (most recent call last):
  File "/home/banikr/PycharmProjects/HeadSeg/dummy.py", line 129, in <module>
    fdloss.backward()
  File "/home/banikr/miniconda3/envs/HeadSeg36/lib/python3.6/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/banikr/miniconda3/envs/HeadSeg36/lib/python3.6/site-packages/torch/autograd/__init__.py", line 132, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Function 'PowBackward0' returned nan values in its 0th output.

But torch.pow((1 - torch.tensor(1.0)), 0.75) shouldn't be a NaN value. I don't see a reason for an Inf or an abnormal Dice value.
I will keep investigating; just sharing the update with you.

The output in the forward pass will be zero, but the gradient would be NaN:

x = torch.tensor(1.0, requires_grad=True)
out = torch.pow((1 - x), 0.75)
print(out)
> tensor(0., grad_fn=<PowBackward0>)

out.backward()
print(x.grad)
> tensor(nan)
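The reason, written out for the exponent used in focal_dice_loss (γ = gamma_fd = 0.75), is that the derivative of the power has a negative exponent:

```latex
\frac{\partial}{\partial x}(1-x)^{\gamma} = -\gamma\,(1-x)^{\gamma-1},
\qquad \gamma - 1 = -0.25 \;\Rightarrow\; (1-x)^{-0.25} \to \infty \text{ as } x \to 1

% autograd step at the pow node, with base b = 1 - x:
\text{grad}_{\text{in}} = \text{grad}_{\text{out}} \cdot \gamma \, b^{\gamma-1},
\qquad b = 0,\ \text{grad}_{\text{out}} = 0 \;\Rightarrow\; 0 \cdot \infty = \text{NaN}
```

In focal_dice_loss, the background entry is dropped by [1:] after torch.pow, so it receives an upstream gradient of 0. If its Tversky index is exactly 1, this is plausibly the element for which PowBackward0 returns NaN. Foreground classes at exactly 1 get an infinite gradient instead, which can turn into NaN further upstream.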
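A possible workaround, sketched as an assumption and not taken from the thread, is to clamp the base away from 0 before calling torch.pow. The derivative γ·b^(γ-1) is then evaluated at b ≥ epsilon and stays finite. focal_term is a hypothetical helper name, and the Tversky values below are made up.

```python
import torch

def focal_term(tversky, gamma_fd=0.75, epsilon=1e-6):
    # clamp the base so that gamma_fd * base**(gamma_fd - 1) stays finite
    base = torch.clamp(1 - tversky, min=epsilon)
    return torch.pow(base, gamma_fd)

# hypothetical per-class Tversky values; classes 0 and 2 hit exactly 1.0
tversky = torch.tensor([1.0, 0.5, 1.0], requires_grad=True)
loss = focal_term(tversky)[1:].sum()  # drop class 0 (background) as in focal_dice_loss
loss.backward()
print(tversky.grad)  # finite; clamped classes get a zero gradient
```

Inside focal_dice_loss this would replace torch.pow((1-tversky_dice), gamma_fd). The clamp also guards against the index rounding slightly above 1. Classes that are already perfect simply stop contributing a gradient.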