How could a BatchNorm layer generate NaN in the backward pass?

I’m training a custom Transformer model with a custom loss function. The loss becomes NaN after some training steps, so I set torch.autograd.set_detect_anomaly(True) to figure out what happened. After adding this line, I get this error:

[W python_anomaly_mode.cpp:104] Warning: Error detected in CudnnBatchNormBackward. Traceback of forward call that caused the error:

... omitted ...

  File "/home/rjia/dl/seq_release/research/", line 389, in forward
    x_preprocessed = self.preprocess_module(X)
  File "/home/rjia/anaconda3/envs/cuda102/lib/python3.8/site-packages/torch/nn/modules/", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/rjia/dl/seq_release/research/", line 915, in forward
    rtn = self.batchnorm(X.view(bz * seq_len, m)).view(bz, seq_len, m)
  File "/home/rjia/anaconda3/envs/cuda102/lib/python3.8/site-packages/torch/nn/modules/", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/rjia/anaconda3/envs/cuda102/lib/python3.8/site-packages/torch/nn/modules/", line 131, in forward
    return F.batch_norm(
  File "/home/rjia/anaconda3/envs/cuda102/lib/python3.8/site-packages/torch/nn/", line 2056, in batch_norm
    return torch.batch_norm(

and finally raises RuntimeError("Function 'CudnnBatchNormBackward' returned nan values in its 1th output.")

I’ve searched Google and the PyTorch forum. Some people mentioned that “if all inputs are the same, or contain NaN, or batch_size == 1, then BatchNorm will generate NaN values in the backward pass”, but I tried feeding all-zero inputs to a BatchNorm1d() layer and backward() works fine in this case:

import torch

m = 5
n = 8

bn = torch.nn.BatchNorm1d(num_features=m)

y = torch.zeros(n, 1)
x = torch.zeros(n, m)  # <----- identical inputs do not make BatchNorm raise an error
w = torch.nn.Parameter(data=torch.rand(1, n))

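for i in range(1000):
    bn.zero_grad()
    w.grad = None
    out = bn(x)
    yh = torch.matmul(w, out)  # use the BatchNorm output so gradients flow through bn
    loss = ((yh - y) ** 2).sum()
    loss.backward()  # exercises the BatchNorm backward pass
    assert not torch.isnan(bn.weight.grad).any()
    assert not torch.isnan(bn.running_mean).any()
    assert not torch.isnan(bn.running_var).any()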

So my question is: in which cases can BatchNorm generate NaN values?


BatchNorm layers shouldn’t return NaN values if the input is well defined.
You could add assert statements to the input and output of this batchnorm layer and check if all values are finite via torch.isfinite(tensor).
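One way to do this check without editing the model's forward method is a forward hook. The following is only a minimal sketch: the helper name check_finite and the standalone BatchNorm1d layer are hypothetical stand-ins for the batchnorm layer in the traceback, and the sketch assumes the layer receives plain tensors as inputs.

```python
import torch
import torch.nn as nn


def check_finite(name):
    """Build a forward hook that fails fast on non-finite inputs or outputs.

    `name` is only a label used in the error message (hypothetical helper).
    """
    def hook(module, inputs, output):
        # `inputs` is the tuple of positional arguments passed to forward().
        for i, t in enumerate(inputs):
            if torch.is_tensor(t) and not torch.isfinite(t).all():
                raise RuntimeError(f"{name}: input {i} contains NaN or Inf")
        if torch.is_tensor(output) and not torch.isfinite(output).all():
            raise RuntimeError(f"{name}: output contains NaN or Inf")
    return hook


# Stand-in for self.batchnorm in the model from the question.
bn = nn.BatchNorm1d(num_features=5)
handle = bn.register_forward_hook(check_finite("batchnorm"))

bn(torch.randn(8, 5))  # finite input passes silently

bad = torch.randn(8, 5)
bad[0, 0] = float("nan")  # inject a single NaN
try:
    bn(bad)
except RuntimeError as err:
    print(err)  # reports which tensor was non-finite

handle.remove()  # detach the hook once debugging is done
```

If the hook fires on the input, the NaN or Inf originates upstream of the batchnorm layer, and the same hook can be attached to earlier modules to narrow it down.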
