Hi, I want to sync the training state of the running statistics with the weights in the BatchNorm layers, so I override the train() method as shown below. But it doesn't seem to take effect.

```python
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        if downsample is not None and dilation > 1:
            dilation = dilation // 2

        assert stride == 1 or dilation == 1, "at least one of stride and dilation must be 1"

        if dilation > 1:
            self.conv2 = nn.Conv2d(planes,
                                   planes,
                                   kernel_size=3,
                                   stride=stride,
                                   padding=dilation,
                                   bias=False,
                                   dilation=dilation)
        else:
            self.conv2 = nn.Conv2d(planes,
                                   planes,
                                   kernel_size=3,
                                   stride=stride,
                                   padding=1,
                                   bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        if out.size() != residual.size():
            print(out.size(), residual.size())
        out += residual

        out = self.relu(out)

        return out

    def train(self, mode=True):
        # If any direct child is a BatchNorm layer that has parameters,
        # put the whole block into eval mode instead of training mode.
        if mode:
            for name, module in self.named_children():
                if isinstance(module, _BatchNorm):
                    for p_name, param in module.named_parameters():
                        print("-" * 10, "frozen bn module found", "-" * 10)
                        mode = False
                        break
                if not mode:
                    break
        super().train(mode)
        return self
```
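For reference, here is a minimal toy sketch of a related pattern, where super().train(mode) is called first and the BatchNorm layers are then switched back to eval mode afterwards. FrozenBNBlock is a made-up module used only for illustration, not the Bottleneck above.

```python
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm


class FrozenBNBlock(nn.Module):
    # Toy module: conv + BN, with BN kept in eval mode even while training.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 2, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(2)

    def train(self, mode=True):
        # Switch the whole module first, then push every BatchNorm child back
        # to eval mode so its running statistics are not updated.
        super().train(mode)
        if mode:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self
```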

To reproduce

```python
test_module = Bottleneck(2, 2)
for name, module in test_module.named_children():
    if 'bn' in name:
        print("module {} training state: ".format(name), module.training)

for name, param in test_module.named_parameters():
    if 'bn' in name: