I’m training a customized Transformer model with a customized loss function. I found that the loss value becomes NaN after some training steps, so I set torch.autograd.set_detect_anomaly(True)
to try to figure out what happened. After adding this line, it gives this error:
[W python_anomaly_mode.cpp:104] Warning: Error detected in CudnnBatchNormBackward. Traceback of forward call that caused the error:
.... omitted ...
File "/home/rjia/dl/seq_release/research/seq_models.py", line 389, in forward
x_preprocessed = self.preprocess_module(X)
File "/home/rjia/anaconda3/envs/cuda102/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/rjia/dl/seq_release/research/seq_models.py", line 915, in forward
rtn = self.batchnorm(X.view(bz * seq_len, m)).view(bz, seq_len, m)
File "/home/rjia/anaconda3/envs/cuda102/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/rjia/anaconda3/envs/cuda102/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py", line 131, in forward
return F.batch_norm(
File "/home/rjia/anaconda3/envs/cuda102/lib/python3.8/site-packages/torch/nn/functional.py", line 2056, in batch_norm
return torch.batch_norm(
and finally it raises RuntimeError("Function 'CudnnBatchNormBackward' returned nan values in its 1th output.")
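For context, the line the traceback points at (seq_models.py, line 915) applies BatchNorm1d over the flattened batch and sequence dimensions. A simplified sketch of that part of my model (class and variable names here are illustrative, not the real code):

import torch

class PreprocessSketch(torch.nn.Module):
    # simplified, illustrative stand-in for the module in seq_models.py
    def __init__(self, m):
        super().__init__()
        self.batchnorm = torch.nn.BatchNorm1d(num_features=m)

    def forward(self, X):
        # X has shape (bz, seq_len, m); flattening the first two dims lets
        # BatchNorm1d normalize each of the m features over bz * seq_len samples
        bz, seq_len, m = X.shape
        return self.batchnorm(X.view(bz * seq_len, m)).view(bz, seq_len, m)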
I’ve searched Google and the PyTorch forum, and some people mentioned that "if all inputs are the same, or the input contains NaN, or batch_size == 1, then batchnorm will generate NaN values in the backward pass". But I’ve tried feeding all-zero inputs to a BatchNorm1d()
layer, and backward() works fine in that case:
import torch

torch.autograd.set_detect_anomaly(True)

m = 5
n = 8
bn = torch.nn.BatchNorm1d(num_features=m)
y = torch.zeros(1, m)  # target shaped to match yh below
x = torch.zeros(n, m)  # <----- all-identical inputs do not cause batchnorm to raise an error
w = torch.nn.Parameter(data=torch.rand(1, n))

for i in range(1000):
    out = bn(x)
    yh = torch.matmul(w, out)  # route the loss through the BN output so backward() actually passes through batchnorm
    loss = ((yh - y) ** 2).sum()
    loss.backward()
    assert not torch.isnan(bn.running_mean).any()
    assert not torch.isnan(bn.running_var).any()
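By contrast, the "input contains NaN" case from the quote above does reproduce the error. Here is a minimal toy check (separate from my actual model):

import torch

torch.autograd.set_detect_anomaly(True)

bn = torch.nn.BatchNorm1d(num_features=5)
x = torch.randn(8, 5, requires_grad=True)
with torch.no_grad():
    x[0, 0] = float("nan")  # a single NaN in the batch poisons the statistics of channel 0

out = bn(x)           # batch mean/var of channel 0 are NaN, so all of channel 0's outputs are NaN
out.sum().backward()  # anomaly detection raises: Function 'NativeBatchNormBackward...' returned nan values ...
                      # (on GPU with cuDNN the flagged node is CudnnBatchNormBackward, as in my traceback)

If something like this is what happens in my model, it would also explain why the loss only turns NaN after some training steps: once a batch containing NaN reaches the BN layer, running_mean and running_var are updated with NaN, so later forward passes in eval mode produce NaN even on clean inputs.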
So my question is: in which cases can BatchNorm generate NaN values in the backward pass?