Hi, I want to keep the training state of the running statistics in the BatchNorm layers in sync with whether their weights are frozen, so I override the train() method as shown below. But it doesn't seem to take effect.
import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        padding = 2 - stride
        if downsample is not None and dilation > 1:
            dilation = dilation // 2
            padding = dilation
        assert stride == 1 or dilation == 1, "at least one of stride and dilation must be 1"
        if dilation > 1:
            padding = dilation
        self.conv2 = nn.Conv2d(planes,
                               planes,
                               kernel_size=3,
                               stride=stride,
                               padding=padding,
                               bias=False,
                               dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)
        if out.size() != residual.size():
            print(out.size(), residual.size())

        out += residual
        out = self.relu(out)
        return out

    def train(self, mode=True):  # default added to match nn.Module.train, so model.train() keeps working
        if mode:
            # if any BatchNorm child has a frozen parameter, switch the whole
            # block to eval so the running statistics stop updating as well
            for name, module in self.named_children():
                if isinstance(module, _BatchNorm):
                    for p_name, param in module.named_parameters():
                        if not param.requires_grad:
                            print("-" * 10, "frozen bn module found", "-" * 10)
                            mode = False
                            break
                if not mode:
                    break
        super().train(mode)
        return self
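For comparison: the override above flips the entire block to eval as soon as one frozen BN layer is found, which also takes the convolutions out of training mode. A minimal sketch of a finer-grained variant (my addition, with a hypothetical subclass name) first applies the standard recursive train() and then switches only the frozen BatchNorm layers back to eval:

class BottleneckPartialFreeze(Bottleneck):
    """Variant: only the frozen BN layers leave training mode."""

    def train(self, mode=True):
        # plain nn.Module.train recursion, bypassing Bottleneck's all-or-nothing override
        nn.Module.train(self, mode)
        if mode:
            for module in self.children():
                if isinstance(module, _BatchNorm) and any(
                        not p.requires_grad for p in module.parameters()):
                    module.eval()  # keep this layer's running statistics fixed
        return self

With this variant the convolutions keep training while the frozen BN layers neither receive gradients nor update their running statistics.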
To reproduce:
test_module = Bottleneck(2, 2)

for name, module in test_module.named_children():
    if 'bn' in name:
        print("module {} training state: ".format(name), module.training)

# freeze the affine parameters of every BN layer
for name, param in test_module.named_parameters():
    if 'bn' in name:
        print("set param {} requires_grad False".format(name))
        param.requires_grad = False

test_module.train(True)
print(test_module.training)

for name, module in test_module.named_children():
    if 'bn' in name:
        print("module {} training state: ".format(name), module.training)