NaN in batchnorm running stats with AMP

Hi, we have a model where some of the layers are:

...
self.shared_1 = nn.Conv2d(512, 64, kernel_size=3, padding=1, bias=False)
self.shared_2 = nn.BatchNorm2d(64, eps=1e-3)
self.shared_3 = nn.ReLU(True)
...

The forward is

x1 = self.shared_1(x)
x2 = self.shared_2(x1)
x3 = self.shared_3(x2)

When using AMP, we found that the running_mean and running_var of self.shared_2 are normal before it receives x1. After producing x2, however, there are some NaNs in its running_mean and running_var. Could the values calculated from x1 exceed the range of half precision?

Is there an elegant approach to avoid the NaNs? Thanks

For instance, how could we keep the above conv layer in FP32?

This should not be the case since autocast will use float32 in batchnorm layers.
Most likely x1 already contains invalid values (Infs or NaNs) and is thus updating the running stats with invalid values, too.
Did you make sure the input does not contain any invalid values?
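
For example, a quick check right before the batchnorm (x1 is the tensor from your snippet):

print(torch.isnan(x1).any(), torch.isinf(x1).any())
# or, as a single check
assert torch.isfinite(x1).all(), "x1 contains Infs or NaNs"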

We checked the min, max, and mean values of x1, and they are all normal.

Besides, the eps value is set to 1e-3 everywhere in the code, including in the torch.log() calls. So, is the calculation of running_mean and running_var in BN still done in FP16, which would exceed its range?

No, the batchnorm layers will use float32 inside the autocast context, as already explained. Unless you are calling half() on this model and using manual float16 training, this operation should not overflow.

We know that autocast(enabled=False) can be used in forward() to run specific layers in FP32. Are there approaches to keep specific layers in FP32 from the __init__ function?
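
For reference, this is the forward-time workaround we mean; a rough sketch assuming a CUDA device (the explicit .float() cast is there because tensors produced by the surrounding autocast region may still be float16):

def forward(self, x):
    x1 = self.shared_1(x)
    # run the batchnorm in full precision, outside of autocast
    with torch.autocast(device_type="cuda", enabled=False):
        x2 = self.shared_2(x1.float())
    x3 = self.shared_3(x2)
    return x3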

Maybe we could find the problematic layer by progressively replacing FP16 layers with FP32 ones…
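
Alternatively, forward hooks could report the first module that produces non-finite outputs, which might be faster than swapping layers one by one; a minimal sketch (model stands for our network):

import torch

def make_hook(name):
    def hook(module, inputs, output):
        # flag the first place where Infs/NaNs show up
        if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
            print(f"non-finite output in {name} ({module.__class__.__name__})")
    return hook

for name, module in model.named_modules():
    module.register_forward_hook(make_hook(name))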

@sakura Please support this simple autocast-module-wrappers feature request at [feature request] Autocast module and function wrappers · Issue #70166 · pytorch/pytorch · GitHub; so far it hasn’t obtained much support/traction :frowning:

Probably it could be as simple as:

import torch

class AutocastModule(torch.nn.Module):
  """Wraps a module so that its forward pass always runs under the given autocast settings."""
  def __init__(self, module, **autocast_options):
    super().__init__()
    self.module = module
    self.autocast_options = autocast_options

  def forward(self, *args, **kwargs):
    with torch.autocast(**self.autocast_options):
      return self.module(*args, **kwargs)
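
With such a wrapper, the layers from the first post could be kept in FP32 directly in __init__, e.g. (just a sketch; note that inputs coming from the surrounding autocast region may still be float16 and might need an explicit .float() cast):

self.shared = AutocastModule(
    nn.Sequential(
        nn.Conv2d(512, 64, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(64, eps=1e-3),
        nn.ReLU(True),
    ),
    device_type="cuda",
    enabled=False,
)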

Thanks. The bug has been solved : )

I got the same bug on my model; how did you solve this problem?
When training, everything is normal, but when it comes to validation, the running_mean and running_var become NaN.

One solution is to do manual updates of the running stats in separate, manually managed tensors. This way, you can filter out the updates/bad values that you don’t want.

One can do it by setting bn.momentum = 1, so that bn.running_mean/bn.running_var actually contain the batch statistics. Then you would use these batch statistics to update the running statistics tensors that you keep on the side. Before evaluation, you need to copy your running statistics back into bn.running_mean/bn.running_var.
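
A minimal sketch of that idea (the side tensors, the 0.1 momentum for the side statistics, and the isfinite filter are assumptions you can adapt):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(64, eps=1e-3, momentum=1.0)  # momentum=1 -> buffers hold the batch stats

# "real" running statistics kept on the side
side_mean = torch.zeros(64)
side_var = torch.ones(64)
side_momentum = 0.1

def update_side_stats():
    # with momentum=1, these buffers contain the current batch statistics
    batch_mean = bn.running_mean.detach().cpu()
    batch_var = bn.running_var.detach().cpu()
    # skip updates that contain Infs/NaNs
    if torch.isfinite(batch_mean).all() and torch.isfinite(batch_var).all():
        side_mean.mul_(1 - side_momentum).add_(side_momentum * batch_mean)
        side_var.mul_(1 - side_momentum).add_(side_momentum * batch_var)

def load_side_stats():
    # call this before switching to eval()
    bn.running_mean.copy_(side_mean.to(bn.running_mean.device))
    bn.running_var.copy_(side_var.to(bn.running_var.device))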