Hi @albanD. Thanks for providing useful insights on debugging this tricky issue! I have encountered the same problem but have been unable to overcome it.
I obtained the following traceback:
[W python_anomaly_mode.cpp:104] Warning: Error detected in MulBackward0. Traceback of forward call that caused the error:
  File "train_eval.py", line 291, in <module>
    start_epoch=0,
  File "train_eval.py", line 126, in train_eval_model
    s_pred_list = model(data_list, points_gt_list, edges_list, n_points_gt_list, perm_mat_list)
  File "/home/user/envs/pytorch-gpu-1.7.1/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/user/code/matching/BB_GM/model.py", line 141, in forward
    for gm_solver, unary_costs, quadratic_costs in zip(gm_solvers, unary_costs_list, quadratic_costs_list)
  File "/home/user/code/matching/BB_GM/model.py", line 141, in <listcomp>
    for gm_solver, unary_costs, quadratic_costs in zip(gm_solvers, unary_costs_list, quadratic_costs_list)
  File "/home/user/envs/pytorch-gpu-1.7.1/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/user/code/matching/BB_GM/ADGM.py", line 448, in forward
    tmp = ADGMWrapper(costs[0], costs[1], edges_left.T, edges_right.T, rounding=(not self.training), **self.solver_params)
  File "/home/user/code/matching/BB_GM/ADGM.py", line 232, in ADGMWrapper
    return ADGM(costs, P, rounding=rounding, **kargs)
  File "/home/user/code/matching/BB_GM/ADGM.py", line 145, in ADGM
    Z = X*temp2
 (function _print_stack)
Traceback (most recent call last):
  File "train_eval.py", line 291, in <module>
    start_epoch=0,
  File "train_eval.py", line 132, in train_eval_model
    loss.backward()
  File "/home/user/envs/pytorch-gpu-1.7.1/lib/python3.7/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/user/envs/pytorch-gpu-1.7.1/lib/python3.7/site-packages/torch/autograd/__init__.py", line 132, in backward
    allow_unreachable=True) # allow_unreachable flag
RuntimeError: Function 'MulBackward0' returned nan values in its 1th output.
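For reference, the warning at the top was produced by running with autograd anomaly detection enabled, which is what records the forward-call stack of the offending op:

import torch
# Checks each backward function's outputs for NaNs and, when one appears,
# prints the forward traceback of the op that created it.
torch.autograd.set_detect_anomaly(True)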
Apparently, the line that caused the issue is Z = X*temp2 in the following code (simplified for convenience):
# Initialization
n1, n2 = U.shape
X = torch.zeros_like(U) + 1.0/n2
Z = torch.zeros_like(U) + 1.0/n1
Y = torch.zeros_like(U)
for i in range(iterations):
    # Update X
    ...
    temp = torch.exp(X - torch.max(X, dim=-1, keepdim=True)[0])
    X = Z*temp
    # Normalize: sum of each row of X is 1
    X = X / torch.max(X, dim=-1, keepdim=True)[0]
    X = X / torch.sum(X, dim=-1, keepdim=True)
    print(f'X normalized:\n {X}')
    print(f'X normalized sum over row:\n {torch.sum(X, dim=-1)}')
    # Update Z
    ...
    temp2 = torch.exp(Z - torch.max(Z, dim=-2, keepdim=True)[0])
    Z = X*temp2
    # Normalize: sum of each column of Z is 1
    Z = Z / torch.max(Z, dim=-2, keepdim=True)[0]
    Z = Z / torch.sum(Z, dim=-2, keepdim=True)
    print(f'Z normalized:\n {Z}')
    print(f'Z normalized sum over column:\n {torch.sum(Z, dim=-2)}')
    # Update Y
    ...
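A side note on the torch.exp calls: subtracting the row/column max is the usual overflow guard, but the shifted exponents can still be very negative, and in float32 torch.exp underflows to exactly 0.0 below roughly -104. I assume this is where the hard zeros in X and Z visible in the logs below come from. A minimal standalone illustration (not from the actual code):

import torch

# float32's smallest subnormal is ~1.4e-45, so exp(x) rounds to exactly
# zero once x drops below about ln(1.4e-45 / 2) ~= -104.
print(torch.exp(torch.tensor(-100.0)))  # tensor(3.7835e-44) -- subnormal, nonzero
print(torch.exp(torch.tensor(-110.0)))  # tensor(0.)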
Here are the logs printed right before the traceback:
X normalized:
tensor([[0.0000e+00, 4.2653e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 9.5735e-01, 0.0000e+00, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 1.7397e-40, 0.0000e+00, 0.0000e+00, 9.9391e-01,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 6.0904e-03],
[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 1.0000e+00, 0.0000e+00, 2.2040e-19, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 1.0000e+00, 0.0000e+00, 1.7096e-30, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 4.9784e-08, 1.4729e-09, 5.5091e-01, 1.1846e-12,
0.0000e+00, 0.0000e+00, 1.8706e-17, 1.5666e-26, 4.4909e-01],
[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 1.0000e+00, 0.0000e+00, 1.6033e-32, 1.0499e-27],
[0.0000e+00, 0.0000e+00, 0.0000e+00, 6.8453e-34, 2.1804e-33, 3.3585e-03,
0.0000e+00, 0.0000e+00, 0.0000e+00, 5.4462e-15, 9.9664e-01],
[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 4.0123e-01, 0.0000e+00, 5.9877e-01, 0.0000e+00],
[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
6.6972e-33, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 3.3273e-06,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 5.2851e-34,
0.0000e+00, 1.0000e+00, 0.0000e+00, 5.3407e-21, 6.9220e-31]],
device='cuda:0', grad_fn=<DivBackward0>)
X normalized sum over row:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0',
grad_fn=<SumBackward1>)
Z normalized:
tensor([[0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 9.9989e-01,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.8221e-04],
[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 4.6598e-04, 0.0000e+00, 3.2422e-22, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 9.9953e-01, 0.0000e+00, 1.1537e-26, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0164e-10,
0.0000e+00, 0.0000e+00, 5.2710e-26, 2.1740e-20, 6.0694e-01],
[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 6.0779e-07, 0.0000e+00, 7.5840e-38, 3.8613e-41],
[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 4.3132e-12,
0.0000e+00, 0.0000e+00, 0.0000e+00, 2.9244e-19, 7.7734e-11],
[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 4.2213e-06, 0.0000e+00, 1.0000e+00, 0.0000e+00],
[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0853e-04,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 3.9288e-01],
[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 5.3232e-07, 0.0000e+00, 9.5763e-26, 2.8026e-45]],
device='cuda:0', grad_fn=<DivBackward0>)
Z normalized sum over column:
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
1.0000, 1.0000], device='cuda:0', grad_fn=<SumBackward1>)
For further investigation, I also printed the values of temp2:
temp2:
tensor([[2.2421e-44, 1.0000e+00, 3.4224e-15, 3.8321e-19, 5.6191e-10, 5.0389e-06,
7.1892e-10, 1.3331e-01, 1.0000e+00, 1.0497e-04, 1.3752e-05],
[3.8088e-39, 7.9228e-24, 4.8935e-07, 9.9650e-15, 1.6200e-08, 1.1725e-02,
4.3550e-10, 2.2858e-04, 3.8238e-20, 1.4496e-07, 2.2137e-02],
[9.2906e-43, 1.9936e-23, 2.1098e-23, 2.9616e-23, 1.8445e-23, 2.6213e-15,
1.8063e-07, 4.6620e-04, 1.5636e-24, 1.0600e-09, 4.0346e-16],
[1.0970e-39, 1.4961e-20, 9.6151e-17, 7.5080e-16, 1.1495e-16, 2.6410e-10,
2.4117e-06, 1.0000e+00, 1.6358e-17, 4.8629e-03, 2.9368e-10],
[7.4011e-41, 5.9593e-24, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
3.8696e-10, 1.3359e-06, 2.6977e-09, 1.0000e+00, 1.0000e+00],
[1.9164e-41, 1.7378e-23, 6.1811e-22, 1.3114e-25, 6.0049e-23, 1.2770e-14,
2.4556e-12, 6.0808e-07, 3.7422e-24, 3.4006e-12, 2.7213e-14],
[4.8908e-35, 4.5041e-33, 9.3833e-18, 4.9370e-20, 5.4802e-17, 1.4968e-11,
7.7122e-11, 5.6934e-12, 4.8426e-30, 3.8693e-11, 5.7712e-11],
[7.4685e-38, 6.6625e-31, 3.3640e-21, 3.4035e-15, 3.7193e-24, 2.0841e-16,
1.8670e-04, 1.0526e-05, 2.1427e-28, 1.2035e-06, 2.4145e-17],
[1.0000e+00, 0.0000e+00, 1.8710e-41, 1.4928e-38, 0.0000e+00, 6.2585e-38,
1.0000e+00, 7.3923e-30, 0.0000e+00, 5.2951e-21, 6.4194e-40],
[1.3871e-40, 2.5689e-23, 7.0096e-08, 3.0429e-13, 1.6090e-07, 3.8015e-01,
1.7917e-10, 6.8646e-05, 7.7878e-19, 4.3145e-07, 2.9071e-01],
[1.2150e-36, 1.0209e-26, 4.6150e-22, 4.5896e-25, 1.7469e-23, 1.8506e-14,
2.4162e-11, 5.3257e-07, 1.5404e-26, 1.2921e-11, 3.5527e-15]],
device='cuda:0', grad_fn=<ExpBackward>)
It seems to me that both X and temp2 look numerically fine, yet the operation Z = X*temp2 returned NaN values in its 1th output (i.e., the gradient w.r.t. temp2, whose local derivative is X). Would you have any ideas on how to fix this, please?
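For what it's worth, my current guess is that the multiplication itself is fine and the NaN comes from the incoming gradient: wherever X is exactly 0 and the gradient flowing back into Z is inf (e.g. from a later division whose denominator hit zero), MulBackward0 computes inf * 0 = NaN on the temp2 side. A minimal standalone sketch that reproduces the exact same error message (the 1.0 / z here is just a stand-in for a downstream op with an exploding gradient, not the actual normalization code):

import torch

torch.autograd.set_detect_anomaly(True)

x = torch.tensor([0.0, 1.0], requires_grad=True)
y = torch.tensor([2.0, 3.0], requires_grad=True)
z = x * y              # z[0] is exactly 0, like the hard zeros in X
out = (1.0 / z).sum()  # forward already contains an inf at index 0
# Backward: grad w.r.t. z is -1/z**2 = -inf at index 0, so MulBackward0
# returns grad w.r.t. y as grad_z * x = -inf * 0 = nan there.
out.backward()
# RuntimeError: Function 'MulBackward0' returned nan values in its 1th output.

If that is indeed the mechanism here, clamping the denominators of the normalization steps (e.g. torch.sum(...).clamp_min(eps)) might avoid the infs, but I haven't verified that this fixes it.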
Thank you very much in advance for your help!