Mixed precision training with nn.SyncBatchNorm returns NaN for running_var

I am training a model with mixed precision (amp_level='O1'). The model is built from a stack of 3D ResNet blocks:

import torch.nn as nn

# conv3x3x3 / conv1x1x1 are small helpers from my codebase that build the
# corresponding 3D convolutions with the given stride/padding/dilation.

class BasicBlock3d(nn.Module):
    """3x3x3 ResNet basic block."""
    expansion = 1
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(BasicBlock3d, self).__init__()

        self.conv1 = conv3x3x3(inplanes, planes, stride, 1, dilation)
        self.bn1 = nn.SyncBatchNorm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes, 1, 1, dilation)
        self.bn2 = nn.SyncBatchNorm(planes)

        if downsample is not None:
            self.downsample = downsample
        elif stride != 1 or inplanes != planes:
            # Project the identity when the shape or channel count changes.
            self.downsample = nn.Sequential(
                conv1x1x1(inplanes, planes, stride),
                nn.SyncBatchNorm(planes),
            )
        else:
            self.downsample = None

        self.stride = stride

    def forward(self, x):
        identity = x

        out1 = self.conv1(x)
        out2 = self.bn1(out1)
        out_ = self.relu(out2)

        out3 = self.conv2(out_)
        out = self.bn2(out3)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
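For reference, the mixed-precision setup corresponds roughly to Apex AMP at opt_level 'O1'. The snippet below is only a sketch with a stand-in model; the optimizer, learning rate, and input shape are placeholders, not my actual training code:

import torch
import torch.nn as nn
from apex import amp  # NVIDIA Apex

# Stand-in model; the real network is the stack of BasicBlock3d blocks above
# (nn.SyncBatchNorm additionally needs a torch.distributed process group).
model = nn.Sequential(
    nn.Conv3d(1, 8, kernel_size=3, padding=1),
    nn.BatchNorm3d(8),
    nn.ReLU(inplace=True),
).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# opt_level='O1' casts the inputs of FP16-safe ops (e.g. convolutions) to
# half precision while the model weights themselves stay in FP32.
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

inputs = torch.randn(2, 1, 8, 32, 32).cuda()
loss = model(inputs).mean()

# Scale the loss so small FP16 gradients do not underflow in backward().
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()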

After training for many epochs (70+), self.bn1.running_var becomes NaN in some of the middle blocks and crashes the training pipeline (loss=NaN). I found the direct cause: the input feature of self.bn1 (out1) has a sum larger than 65,504 (the largest finite FP16 value) in certain channels, so its mean becomes inf. One example is shown below:

In[32]:out1[0,15,0,1,:]
Out[32]: 
tensor([14720., 21936., 21776., 21584., 21376., 21120., 20848., 20544., 20272.,
        19984., 19760., 19568., 19376., 19168., 18928., 18656., 18416., 18192.,
        18032., 17904., 17776., 17632., 17504., 17392., 17296., 17264., 17216.,
        17168., 17120., 17088., 17072., 17056., 17040., 17072., 17136., 17216.,
        17264., 17408., 17584., 17824., 17968., 18080., 18176., 18240., 18304.,
        18352., 18400., 18448., 18528., 18624., 18784., 18992., 19232., 19520.,
        19840., 20192., 20560., 20976., 21408., 21888., 22336., 22720., 22960.,
        15320.], dtype=torch.float16, grad_fn=<SliceBackward>)

In[33]:torch.mean(out1[0,15,0,1,:])
Out[33]: tensor(inf, dtype=torch.float16, grad_fn=<MeanBackward0>)

In[35]:torch.sum(out1[0,15,0,1,:])
Out[35]: tensor(inf, dtype=torch.float16, grad_fn=<SumBackward0>)

In[36]:torch.sum(out1[0,15,0,1,:].float())
Out[36]: tensor(1206136., grad_fn=<SumBackward0>)

In[38]:out2[0,15].isnan().all()
Out[38]: tensor(True)
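For reference, the overflow can be reproduced in isolation from the FP16 limit (a standalone snippet, not taken from the training run):

import torch

# FP16 can only represent finite values up to 65504.
print(torch.finfo(torch.float16).max)   # 65504.0

# A channel whose true sum (1,280,000) is far beyond that limit: the FP16
# result overflows to inf, while the same reduction in FP32 is fine.
x = torch.full((64,), 20000.0, dtype=torch.float16)
print(x.sum())           # tensor(inf, dtype=torch.float16)
print(x.float().sum())   # tensor(1280000.)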

Is there any way to prevent the features from growing this large during training?

I am running into the same problem.

The activation values are not clipped or manipulated in any way, so exceeding the FP16 range can yield these NaN values. I'm not familiar with your workflow, but I would recommend normalizing the inputs, as is done in the majority of models. This would keep the parameters as well as the activations in a reasonable range and avoid running into these overflows.
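Not a drop-in for your pipeline, just a minimal sketch of what I mean by normalizing the inputs (the per-channel statistics here are placeholders you would compute from your own training data):

import torch

# Placeholder per-channel statistics; compute these from your own training set.
channel_mean = torch.tensor([0.45])
channel_std = torch.tensor([0.22])

def normalize_volume(x):
    # x has shape (N, C, D, H, W); standardize each channel so the inputs are
    # roughly zero-mean / unit-variance, which helps keep downstream
    # activations in a range that FP16 can represent.
    mean = channel_mean.to(x.device).view(1, -1, 1, 1, 1)
    std = channel_std.to(x.device).view(1, -1, 1, 1, 1)
    return (x - mean) / std

batch = torch.rand(2, 1, 16, 64, 64)   # example 3D input batch
normalized = normalize_volume(batch)
print(normalized.mean().item(), normalized.std().item())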

CC @yft123