Switch off batch norm layers

During inference (model.eval()), batch norm layers use their stored running statistics, so they are effectively frozen. During training (model.train()), however, those running statistics are updated from every batch, and the affine weight/bias can still receive gradients. To avoid this, you need to explicitly freeze batch norm during training. The cleanest way is to override the train() method in your nn.Module (i.e. your model definition) so that it puts the batch norm layers back into eval mode whenever the model is switched to training. Here is an example:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


class DenseNetConv(nn.Module):
    def __init__(self):
        super(DenseNetConv, self).__init__()
        original_model = models.densenet161(pretrained=False)
        # Keep everything except the final classifier layer.
        self.features = nn.Sequential(*list(original_model.children())[:-1])
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # Freeze all parameters of the feature extractor.
        for param in self.parameters():
            param.requires_grad = False

    def train(self, mode=True):
        """
        Override the default train() to freeze the BN parameters
        """
        super(DenseNetConv, self).train(mode)

        print("Freezing Mean/Var of BatchNorm2D.")
        print("Freezing Weight/Bias of BatchNorm2D.")

        for m in self.features.modules():
            if isinstance(m, nn.BatchNorm2d):
                # eval() stops the running mean/var from being updated.
                m.eval()
                # Disabling gradients keeps the affine weight/bias fixed.
                m.weight.requires_grad = False
                m.bias.requires_grad = False

    def forward(self, x):
        x = self.features(x)
        x = F.relu(x, inplace=True)
        # Adaptive pooling to 1x1 (equivalent to avg_pool2d with kernel_size=7
        # for the standard 224x224 input), then flatten to a feature vector.
        x = self.avgpool(x).view(x.size(0), -1)
        return x

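As a quick sanity check (a minimal sketch, not part of the original post), you can confirm that calling train() on the wrapper leaves the batch norm layers in eval mode while the rest of the module is in training mode; the dummy input size and the expected 2208-dim output for densenet161 features are assumptions for illustration:

model = DenseNetConv()
model.train()  # prints the freezing messages from the overridden train()

# Every BatchNorm2d should report training == False, and its affine
# parameters should not require gradients.
for m in model.features.modules():
    if isinstance(m, nn.BatchNorm2d):
        assert m.training is False
        assert not m.weight.requires_grad and not m.bias.requires_grad

# A dummy forward pass still works as usual (assumes 224x224 RGB input).
out = model(torch.randn(2, 3, 224, 224))
print(out.shape)  # expected: torch.Size([2, 2208]) for densenet161 features
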
Please refer to this thread for more information.
