NaN in batchnorm with AMP

Hi, we have a model where some layers are

self.shared_1 = nn.Conv2d(512, 64, kernel_size=3, padding=1, bias=False)
self.shared_2 = nn.BatchNorm2d(64, eps=1e-3)
self.shared_3 = nn.ReLU(True)

The forward is

x1 = self.shared_1(x)
x2 = self.shared_2(x1)
x3 = self.shared_3(x2)

When using AMP, we found that the running_mean and running_var of self.shared_2 are normal before it receives x1. After it produces x2, however, there are some NaNs in its running_mean and running_var. Maybe the values calculated from x1 exceed the range of half precision?
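For reference, a minimal sketch of what "exceeding the range of half precision" looks like. The values are made up for illustration and are not taken from the model above:

```python
import torch

# Largest finite value representable in float16.
fp16_max = torch.finfo(torch.float16).max

# Float32 values beyond that limit become +/-inf when cast to float16.
big = torch.tensor([70000.0, -70000.0, 1.0])
as_half = big.half()

# inf + (-inf) is NaN, which is one way an overflow can later show up as NaN.
# The addition is done in float32 here; the infs survive the cast back.
diff = as_half[0].float() + as_half[1].float()

print(fp16_max)
print(torch.isinf(as_half))
print(torch.isnan(diff))
```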

Is there an elegant approach to avoid the NaNs? Thanks.

For instance, how could we keep the above conv layer in FP32?

This should not be the case since autocast will use float32 in batchnorm layers.
Most likely x1 already contained invalid values (Infs or NaNs) and is thus updating the running stats with invalid values, too.
Did you make sure the input does not contain any invalid values?
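A quick way to run that check could look like the sketch below. The helper name summarize_invalid is hypothetical; it only relies on torch.isnan, torch.isinf, and torch.isfinite:

```python
import torch

def summarize_invalid(t: torch.Tensor) -> dict:
    """Count NaN and Inf entries and report the largest finite magnitude."""
    # Work on a detached float32 copy so the check itself cannot overflow.
    t32 = t.detach().float()
    finite = torch.isfinite(t32)
    max_abs = t32[finite].abs().max().item() if finite.any() else float("nan")
    return {
        "nan": int(torch.isnan(t32).sum()),
        "inf": int(torch.isinf(t32).sum()),
        "max_abs_finite": max_abs,
    }
```

Calling summarize_invalid(x1) right after the conv, inside the same autocast context used for training, would show whether x1 already carries NaNs or Infs, and how close its finite values are to the float16 limit.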

We checked the min, max, and mean values of x1, and they are all normal.

Besides, the eps values in all of our code are set to 1e-3, including those used with torch.log(). So, is the calculation of running_mean and running_var in the batchnorm still done in FP16, which exceeds its range?

No, the batchnorm layers will use float32 inside the autocast context, as already explained. Unless you are calling half() on this model and using manual float16 training, this operation should not overflow.

We know that autocast(enabled=False) can be used in forward() to keep specific layers in FP32. Is there a way to keep specific layers in FP32 from the __init__ function?
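For context, the forward()-based approach might look like the sketch below. The class name SharedBlock and the device_type argument are assumptions made for illustration; the layers are the ones from the first post:

```python
import torch
import torch.nn as nn

class SharedBlock(nn.Module):
    def __init__(self, device_type: str = "cuda"):
        super().__init__()
        # device_type is an illustrative parameter; it must match the outer autocast.
        self.device_type = device_type
        self.shared_1 = nn.Conv2d(512, 64, kernel_size=3, padding=1, bias=False)
        self.shared_2 = nn.BatchNorm2d(64, eps=1e-3)
        self.shared_3 = nn.ReLU(True)

    def forward(self, x):
        # Locally disable autocast so the conv runs in float32.
        # The input is cast explicitly because it may already be lower precision.
        with torch.autocast(device_type=self.device_type, enabled=False):
            x1 = self.shared_1(x.float())
        x2 = self.shared_2(x1)
        return self.shared_3(x2)
```

The explicit x.float() matters: disabling autocast only stops automatic casting and does not convert tensors that were already produced in lower precision by earlier layers.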

Maybe we could find the buggy layer by progressively switching FP16 layers to FP32…
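An alternative to swapping layers one by one could be to let forward hooks report which modules emit NaN or Inf. This is only a sketch: find_nonfinite_layers is a hypothetical helper built on the standard register_forward_hook API, and it should be called inside the same autocast context as training:

```python
import torch
import torch.nn as nn

def find_nonfinite_layers(model: nn.Module, *inputs):
    """Run one forward pass and return the names of submodules whose
    tensor output contains NaN or Inf, in the order their hooks fired."""
    bad = []
    handles = []

    def make_hook(name):
        def hook(module, args, output):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                bad.append(name)
        return hook

    for name, module in model.named_modules():
        if name:  # skip the root module itself
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        # Note: in train mode this pass still updates batchnorm running stats.
        with torch.no_grad():
            model(*inputs)
    finally:
        # Always remove the hooks so the model is left unchanged.
        for h in handles:
            h.remove()
    return bad
```

The first name in the returned list is the earliest module whose output went non-finite, which is usually the place to start looking.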

@sakura Please support this simple autocast-module-wrappers feature request: "[feature request] Autocast module and function wrappers", Issue #70166 in pytorch/pytorch on GitHub. So far it hasn’t obtained much support/traction :(

Probably it could be as simple as:

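import torch

class AutocastModule(torch.nn.Module):
  """Runs the wrapped module under torch.autocast with the given options."""
  def __init__(self, module, **autocast_options):
    super().__init__()  # required before assigning submodules
    self.module = module
    # e.g. device_type="cuda", enabled=False to opt this module out of autocast
    self.autocast_options = autocast_options

  def forward(self, *args, **kwargs):
    with torch.autocast(**self.autocast_options):
      return self.module(*args, **kwargs)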

Thanks. The bug has been solved : )