Nan in backward pass for torch.square()

When using detect_anomoly, I’m getting an nan in the backward pass of a squaring function. This confuses me because both the square and its derivative should not give nans at any point. I’ve checked that the nan arises in the backward pass and not the forward pass. Am I missing something here?

Here’s the full error:
The forward pass call where the error occurs is

y_2 = torch.pow(diff, 2) 

and the specific error is

RuntimeError: Function 'PowBackward0' returned nan values in its 0th output.

Hi,

Does diff contain some pathological values?
Does the gradient flowing back does? You can print it by doing:

y_2 = torch.pow(diff, 2) 
y_2.register_hook(print)  # This will print the gradient when it's computed.

Hi albanD, thanks for the response. Here’s the full code, I’m doing a simple total variation function that can be found here: https://www.wikiwand.com/en/Total_variation_denoising. I’ve checked for nans in the forward pass as well as using a hook for the backward pass, and for some reason there are no nans in either pass.

def get_tv(x, phase):                                                                                                                                                                                              
    findnan = lambda s, x: print(s, (x != x).any().item())

    x = x.float()
    befor = x[:, 0, :, :-1]
    after = x[:, 0, :, 1:] 
    diff = befor-after
    x_1 = torch.pow(diff, 2)
    findnan('x_1', x_1)
    x_1.register_hook(lambda grad: print('x_1 back', (grad != grad).any().item()))

    befor = x[:, 1, :, :-1]
    after = x[:, 1, :, 1:] 
    diff = befor-after
    x_2 = torch.pow(diff, 2)
    findnan('x_2', x_2)
    x_2.register_hook(lambda grad: print('x_2 back', (grad != grad).any().item()))

    x_sqrt = torch.sqrt(x_1 + x_2)
    tv_x = torch.sum(x_sqrt)

    befor = x[:, 0, :-1, :]
    after = x[:, 0, 1:, :]
    diff = befor-after
    y_1 = torch.pow(diff, 2)
    findnan('y_1', y_1)
    y_1.register_hook(lambda grad: print('y_1 back', (grad != grad).any().item()))

    befor = x[:, 1, :-1, :]
    after = x[:, 1, 1:, :]
    diff = befor-after
    y_2 = torch.pow(diff, 2)
    findnan('y_2', y_2)
    y_2.register_hook(lambda grad: print('y_2 back', (grad != grad).any().item()))

    y_sqrt = torch.sqrt(y_1 + y_2)
    tv_y = torch.sum(y_sqrt)
    return tv_x + tv_y

And here’s the full error output:

x_1 False
x_2 False
y_1 False
y_2 False
y_2 back False
/opt/conda/conda-bld/pytorch_1570711283072/work/torch/csrc/autograd/python_anomaly_mode.cpp:57: UserWarning: Traceback of forward call that caused the error:
  File "run.py", line 84, in <module>
    train.train(xdata, gt_data, vars(args))
  ...
  File "...", line 119, in unsup_loss
    tv = get_tv(x_hat, phase)
  File "...", line 51, in get_tv
    y_2 = torch.pow(diff, 2)

Traceback (most recent call last):
  File "run.py", line 84, in <module>
    train.train(xdata, gt_data, vars(args))
  ...
  File "...", line 118, in trainer
    loss.backward()
  File "/nfs01/shared_software/anaconda3/envs/aw847-py37/lib/python3.7/site-packages/torch/tensor.py", line 150, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/nfs01/shared_software/anaconda3/envs/aw847-py37/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Function 'PowBackward0' returned nan values in its 0th output.

Can you add a hook on the input of this pow() and see what the gradient actually is when you de-active anomaly mode?

I added some more hooks and prints to the function, which I pasted at the bottom. It seems that setting set_detect_anomaly on prevents any nan detection (all the outputs are False). Is this expected?

With set_detect_anomaly(False), the output is

x_1 False
x_2 False
y_1 False
y_2 False
tv_y back False
y_sqrt back False
y_2 back False
diff_4 back True
y_1 back False
diff_3 back True
x_2 back False
diff_2 back False
x_1 back False
diff_1 back False

Epoch 211:
x_1 True
x_2 True
y_1 True
y_2 True
tv_y back False
y_sqrt back False
y_2 back True
diff_4 back True
y_1 back True
diff_3 back True
x_2 back True
diff_2 back True
x_1 back True
diff_1 back True

With set_detect_anomaly(True), the output is

x_1 False
x_2 False
y_1 False
y_2 False
tv_y back False
y_sqrt back False
y_2 back False
diff_4 back False
y_1 back False
diff_3 back False
x_2 back False
diff_2 back False
x_1 back False
diff_1 back False

Epoch 211:
x_1 False
x_2 False
y_1 False
y_2 False
tv_y back False
y_sqrt back False
y_2 back False
/opt/conda/conda-bld/pytorch_1570711283072/work/torch/csrc/autograd/python_anomaly_mode.cpp:57: UserWarning: Traceback of forward call that caused the error:
  File "run.py", line 84, in <module>
    train.train(xdata, gt_data, vars(args))
  ...
  File "...", line 119, in unsup_loss
    tv = get_tv(x_hat, phase)
  File "...", line 51, in get_tv
    y_2 = torch.pow(diff, 2)

Traceback (most recent call last):
  File "run.py", line 84, in <module>
    train.train(xdata, gt_data, vars(args))
  ...
  File "...", line 118, in trainer
    loss.backward()
  File "/nfs01/shared_software/anaconda3/envs/aw847-py37/lib/python3.7/site-packages/torch/tensor.py", line 150, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/nfs01/shared_software/anaconda3/envs/aw847-py37/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Function 'PowBackward0' returned nan values in its 0th output.

For my problem, it seems that the error is when calculating the diff. Why is a subtraction operator giving an nan?

Full function:

def get_tv(x, phase):
    findnan = lambda s, x: print(s, (x != x).any().item())

    x = x.float()
    befor = x[:, 0, :, :-1]
    after = x[:, 0, :, 1:]
    diff = befor-after
    if phase == 'train':                                                                                                                                                                                                                                                        
        diff.register_hook(lambda grad: print('diff_1 back', (grad != grad).any().item()))
    x_1 = torch.pow(diff, 2)
    findnan('x_1', x_1)
    x_1.register_hook(lambda grad: print('x_1 back', (grad != grad).any().item()))

    befor = x[:, 1, :, :-1]
    after = x[:, 1, :, 1:]
    diff = befor-after
    diff.register_hook(lambda grad: print('diff_2 back', (grad != grad).any().item()))
    x_2 = torch.pow(diff, 2)
    findnan('x_2', x_2)
    x_2.register_hook(lambda grad: print('x_2 back', (grad != grad).any().item()))

    x_sqrt = torch.sqrt(x_1 + x_2)
    tv_x = torch.sum(x_sqrt)

    befor = x[:, 0, :-1, :]
    after = x[:, 0, 1:, :]
    diff = befor-after
    diff.register_hook(lambda grad: print('diff_3 back', (grad != grad).any().item()))
    y_1 = torch.pow(diff, 2)
    findnan('y_1', y_1)
    y_1.register_hook(lambda grad: print('y_1 back', (grad != grad).any().item()))

    befor = x[:, 1, :-1, :]
    after = x[:, 1, 1:, :]
    diff = befor-after
    diff.register_hook(lambda grad: print('diff_4 back', (grad != grad).any().item()))
    y_2 = torch.pow(diff, 2)
    findnan('y_2', y_2)
    y_2.register_hook(lambda grad: print('y_2 back', (grad != grad).any().item()))

    y_sqrt = torch.sqrt(y_1 + y_2)
    y_sqrt.register_hook(lambda grad: print('y_sqrt back', (grad != grad).any().item()))
    tv_y = torch.sum(y_sqrt)
    tv_y.register_hook(lambda grad: print('tv_y back', (grad != grad).any().item()))
    return tv_x + tv_y

inf - inf would give nan IIRC.

It seems that setting set_detect_anomaly on prevents any nan detection (all the outputs are False). Is this expected?

No but maybe your code is “flaky” and depending on the run, the nan don’t appear at the same place?

Yes, inf-inf was exactly the issue! Thanks for the help.

1 Like