nn.Module.train(False) doesn't take effect

Hi, here I want to sync the training state of the running statistics in the batch norm layers with whether their weights are frozen, so I override the train() method in the following form. But it doesn't seem to take effect.

import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        padding = 2 - stride
        if downsample is not None and dilation > 1:
            dilation = dilation // 2
            padding = dilation

        assert stride == 1 or dilation == 1, "at least one of stride and dilation must equal 1"

        if dilation > 1:
            padding = dilation
        self.conv2 = nn.Conv2d(planes,
                               planes,
                               kernel_size=3,
                               stride=stride,
                               padding=padding,
                               bias=False,
                               dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        if out.size() != residual.size():
            print(out.size(), residual.size())
        out += residual

        out = self.relu(out)

        return out

    def train(self, mode=True):
        if mode:
            # If any batch norm child has frozen parameters, keep the whole
            # block (including the bn running statistics) in eval mode.
            for name, module in self.named_children():
                if isinstance(module, _BatchNorm):
                    for p_name, param in module.named_parameters():
                        if param.requires_grad is False:
                            print("-" * 10, "frozen bn module found", "-" * 10)
                            mode = False
                            break
                if not mode:
                    break
        super().train(mode)
        return self
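
For context, the override relies on the documented behavior of nn.Module.train(mode): it sets self.training = mode and then recursively calls train(mode) on every child module. A minimal standalone sketch of that recursion (the Sequential here is just for illustration, not part of the model):

import torch.nn as nn

# train(mode) sets self.training and recurses into all children,
# so calling it on a container reaches every submodule.
m = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))
m.train(False)
print(m.training, m[0].training, m[1].training)  # False False False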

To reproduce

test_module = Bottleneck(2, 2)

# freshly constructed modules start in training mode
for name, module in test_module.named_children():
    if 'bn' in name:
        print("module {} training state: ".format(name), module.training)

# freeze every batch norm parameter
for name, param in test_module.named_parameters():
    if 'bn' in name:
        print("set param {} requires_grad False".format(name))
        param.requires_grad = False

test_module.train(True)
print(test_module.training)
for name, module in test_module.named_children():
    if 'bn' in name:
        print("module {} training state: ".format(name), module.training)