I added some more hooks and prints to the function, which I pasted at the bottom. It seems that turning set_detect_anomaly on prevents any NaN detection by the hooks (all the outputs are False). Is this expected?
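For reference, this is how the flag gets flipped between the two runs (simplified; the hooks themselves are registered inside get_tv, pasted at the bottom):

import torch

# Simplified: the flag is flipped once per run; get_tv (pasted at the bottom)
# registers the hooks and does the forward-pass NaN prints.
torch.autograd.set_detect_anomaly(True)   # False for the first output below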
With set_detect_anomaly(False), the output is:
x_1 False
x_2 False
y_1 False
y_2 False
tv_y back False
y_sqrt back False
y_2 back False
diff_4 back True
y_1 back False
diff_3 back True
x_2 back False
diff_2 back False
x_1 back False
diff_1 back False
Epoch 211:
x_1 True
x_2 True
y_1 True
y_2 True
tv_y back False
y_sqrt back False
y_2 back True
diff_4 back True
y_1 back True
diff_3 back True
x_2 back True
diff_2 back True
x_1 back True
diff_1 back True
With set_detect_anomaly(True), the output is:
x_1 False
x_2 False
y_1 False
y_2 False
tv_y back False
y_sqrt back False
y_2 back False
diff_4 back False
y_1 back False
diff_3 back False
x_2 back False
diff_2 back False
x_1 back False
diff_1 back False
Epoch 211:
x_1 False
x_2 False
y_1 False
y_2 False
tv_y back False
y_sqrt back False
y_2 back False
/opt/conda/conda-bld/pytorch_1570711283072/work/torch/csrc/autograd/python_anomaly_mode.cpp:57: UserWarning: Traceback of forward call that caused the error:
File "run.py", line 84, in <module>
train.train(xdata, gt_data, vars(args))
...
File "...", line 119, in unsup_loss
tv = get_tv(x_hat, phase)
File "...", line 51, in get_tv
y_2 = torch.pow(diff, 2)
Traceback (most recent call last):
File "run.py", line 84, in <module>
train.train(xdata, gt_data, vars(args))
...
File "...", line 118, in trainer
loss.backward()
File "/nfs01/shared_software/anaconda3/envs/aw847-py37/lib/python3.7/site-packages/torch/tensor.py", line 150, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/nfs01/shared_software/anaconda3/envs/aw847-py37/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: Function 'PowBackward0' returned nan values in its 0th output.
For my problem, it seems that the error shows up when calculating the diff. Why would a simple subtraction operation give a NaN?
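To try to narrow it down, here is a minimal standalone snippet that reproduces the same PowBackward0 error. My guess (and it is only a guess) is that the problem is not the subtraction itself but the sqrt of the squared diffs: d/du sqrt(u) is inf at u = 0, and inf * 0 then becomes NaN inside PowBackward0:

import torch

torch.autograd.set_detect_anomaly(True)

# Guessed failure mode: if befor == after for some pixel pair, diff is exactly
# 0, sqrt(diff**2) is 0 in the forward pass, the sqrt backward produces
# inf (1 / (2*sqrt(0))), and PowBackward0 then computes inf * 2 * diff
# = inf * 0 = nan.
diff = torch.zeros(3, requires_grad=True)
sq = torch.pow(diff, 2)
out = torch.sqrt(sq).sum()
try:
    out.backward()
except RuntimeError as e:
    print(e)   # Function 'PowBackward0' returned nan values in its 0th output.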
Full function:
import torch

def get_tv(x, phase):
    # Total variation of x: horizontal differences (tv_x) and vertical
    # differences (tv_y), combined over channels 0 and 1.
    findnan = lambda s, x: print(s, (x != x).any().item())
    x = x.float()

    befor = x[:, 0, :, :-1]
    after = x[:, 0, :, 1:]
    diff = befor - after
    if phase == 'train':
        diff.register_hook(lambda grad: print('diff_1 back', (grad != grad).any().item()))
    x_1 = torch.pow(diff, 2)
    findnan('x_1', x_1)
    x_1.register_hook(lambda grad: print('x_1 back', (grad != grad).any().item()))

    befor = x[:, 1, :, :-1]
    after = x[:, 1, :, 1:]
    diff = befor - after
    diff.register_hook(lambda grad: print('diff_2 back', (grad != grad).any().item()))
    x_2 = torch.pow(diff, 2)
    findnan('x_2', x_2)
    x_2.register_hook(lambda grad: print('x_2 back', (grad != grad).any().item()))

    x_sqrt = torch.sqrt(x_1 + x_2)
    tv_x = torch.sum(x_sqrt)

    befor = x[:, 0, :-1, :]
    after = x[:, 0, 1:, :]
    diff = befor - after
    diff.register_hook(lambda grad: print('diff_3 back', (grad != grad).any().item()))
    y_1 = torch.pow(diff, 2)
    findnan('y_1', y_1)
    y_1.register_hook(lambda grad: print('y_1 back', (grad != grad).any().item()))

    befor = x[:, 1, :-1, :]
    after = x[:, 1, 1:, :]
    diff = befor - after
    diff.register_hook(lambda grad: print('diff_4 back', (grad != grad).any().item()))
    y_2 = torch.pow(diff, 2)
    findnan('y_2', y_2)
    y_2.register_hook(lambda grad: print('y_2 back', (grad != grad).any().item()))

    y_sqrt = torch.sqrt(y_1 + y_2)
    y_sqrt.register_hook(lambda grad: print('y_sqrt back', (grad != grad).any().item()))
    tv_y = torch.sum(y_sqrt)
    tv_y.register_hook(lambda grad: print('tv_y back', (grad != grad).any().item()))

    return tv_x + tv_y
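For completeness, this is roughly how the function gets exercised (dummy shape here; in the real code it is called from unsup_loss as tv = get_tv(x_hat, phase)):

# Dummy-shaped stand-in input just to show how the hooks fire; in the real
# code the call is `tv = get_tv(x_hat, phase)` inside unsup_loss.
x_hat = torch.randn(2, 2, 8, 8, requires_grad=True)   # (batch, channel, H, W)
tv = get_tv(x_hat, 'train')
tv.backward()   # the "... back" lines above come from the registered hooks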