I am training a model built from a stack of 3D ResNet blocks with mixed precision (amp_level='O1').
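For context, O1 here corresponds to NVIDIA apex's patch-level mixed precision mode (the exact wrapping is handled by my training framework, but it amounts to roughly the following sketch):

import torch
from apex import amp

model = torch.nn.Conv3d(16, 16, kernel_size=3).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# O1 patches torch functions so whitelisted ops (conv, matmul) run in
# fp16 while blacklisted ops are kept in fp32; the activations flowing
# between them are fp16 tensors.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

Each block looks like this: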
import torch.nn as nn

# Assumed helper definitions (standard 3D analogues of torchvision's
# conv3x3 / conv1x1); my real code defines them equivalently.
def conv3x3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False,
                     dilation=dilation)

def conv1x1x1(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride,
                     bias=False)

class BasicBlock3d(nn.Module):
    """3x3x3 ResNet basic block."""
    expansion = 1
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(BasicBlock3d, self).__init__()
        self.conv1 = conv3x3x3(inplanes, planes, stride, 1, dilation)
        self.bn1 = nn.SyncBatchNorm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes, 1, 1, dilation)
        self.bn2 = nn.SyncBatchNorm(planes)
        if downsample is not None:
            self.downsample = downsample
        elif stride != 1 or inplanes != planes:
            # Project the identity branch when the stride or width changes.
            self.downsample = nn.Sequential(
                conv1x1x1(inplanes, planes, stride),
                nn.SyncBatchNorm(planes),
            )
        else:
            self.downsample = None
        self.stride = stride

    def forward(self, x):
        identity = x
        out1 = self.conv1(x)
        out2 = self.bn1(out1)   # bn1 is the layer whose running_var turns NaN
        out_ = self.relu(out2)
        out3 = self.conv2(out_)
        out = self.bn2(out3)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out
After training for many epochs (70+), self.bn1.running_var becomes NaN in some of the middle blocks, which crashes the training pipeline (loss=NaN). The direct cause is that the input feature of self.bn1 (out1) has a per-channel sum larger than 65,504, the fp16 maximum, so the sum overflows to inf and the channel mean becomes inf as well. One example is shown below:
In[32]:out1[0,15,0,1,:]
Out[32]:
tensor([14720., 21936., 21776., 21584., 21376., 21120., 20848., 20544., 20272.,
19984., 19760., 19568., 19376., 19168., 18928., 18656., 18416., 18192.,
18032., 17904., 17776., 17632., 17504., 17392., 17296., 17264., 17216.,
17168., 17120., 17088., 17072., 17056., 17040., 17072., 17136., 17216.,
17264., 17408., 17584., 17824., 17968., 18080., 18176., 18240., 18304.,
18352., 18400., 18448., 18528., 18624., 18784., 18992., 19232., 19520.,
19840., 20192., 20560., 20976., 21408., 21888., 22336., 22720., 22960.,
15320.], dtype=torch.float16, grad_fn=<SliceBackward>)
In[33]:torch.mean(out1[0,15,0,1,:])
Out[33]: tensor(inf, dtype=torch.float16, grad_fn=<MeanBackward0>)
In[35]:torch.sum(out1[0,15,0,1,:])
Out[35]: tensor(inf, dtype=torch.float16, grad_fn=<SumBackward0>)
In[36]:torch.sum(out1[0,15,0,1,:].float())
Out[36]: tensor(1206136., grad_fn=<SumBackward0>)
In[38]:out2[0,15].isnan().all()
Out[38]: tensor(True)
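Note that every individual value above is well within the fp16 range; it is only the reduction that overflows. A minimal standalone check of this, independent of the model:

import torch

# fp16 can represent values only up to 65504; a sum that exceeds this
# saturates to inf, and inf then poisons the batch-norm statistics.
print(torch.finfo(torch.float16).max)   # 65504.0

x = torch.full((64,), 20000.0, dtype=torch.float16)
print(x.sum())           # tensor(inf, dtype=torch.float16)
print(x.float().sum())   # tensor(1280000.) -- fine once accumulated in fp32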
Is there any way to prevent the features from growing this large during training?
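For example, would forcing the norm layers to compute their statistics in fp32 be a reasonable workaround? A sketch of what I have in mind (FP32SyncBatchNorm is just a hypothetical wrapper of mine; I have not verified how it interacts with apex's O1 patching):

import torch.nn as nn

class FP32SyncBatchNorm(nn.SyncBatchNorm):
    """Hypothetical wrapper: compute batch-norm statistics in fp32
    even when the surrounding activations are fp16."""
    def forward(self, x):
        # Upcast so the per-channel sums cannot overflow fp16,
        # then cast the normalized result back to the input dtype.
        return super().forward(x.float()).to(x.dtype)

Each self.bn* above would then become FP32SyncBatchNorm(planes). This should avoid the overflow inside the statistics, but it does not stop the activations themselves from growing, so I am wondering whether something more fundamental (clipping, extra normalization, a lower learning rate) is needed.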