Switch off batch norm layers

I have the following DenseNet-based setup, which extracts features from an untrained network and feeds the resulting tensors into a fully connected model:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


class DenseNetConv(torch.nn.Module):
    def __init__(self):
        super(DenseNetConv, self).__init__()
        original_model = models.densenet161(pretrained=False)
        # Keep everything except the final classifier layer.
        self.features = torch.nn.Sequential(*list(original_model.children())[:-1])
        self.avgpool = nn.AdaptiveAvgPool2d(1)  # note: unused; pooling is done in forward()
        # Freeze the (randomly initialised) feature extractor.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        x = self.features(x)
        x = F.relu(x, inplace=True)
        x = F.avg_pool2d(x, kernel_size=7).view(x.size(0), -1)
        return x

class MyDenseNetDens(torch.nn.Module):
    def __init__(self, nb_out=2):
        super().__init__()
        self.dens1 = torch.nn.Linear(in_features=2208, out_features=512)
        self.dens2 = torch.nn.Linear(in_features=512, out_features=128)
        self.dens3 = torch.nn.Linear(in_features=128, out_features=nb_out)
        
    def forward(self, x):
        x = self.dens1(x)
        x = F.selu(x)
        x = F.dropout(x, p=0.25, training=self.training)
        x = self.dens2(x)
        x = F.selu(x)
        x = F.dropout(x, p=0.25, training=self.training)
        x = self.dens3(x)
        return x


class MyDenseNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mrnc = DenseNetConv()
        self.mrnd = MyDenseNetDens()
    def forward(self, x):
        x = self.mrnc(x)
        x = self.mrnd(x)
        return x 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

densenet = MyDenseNet()
densenet.to(device)
densenet.train()

However, what is the best way of switching off the batch norm layers in this model for both training and inference? I would typically put the feature-extraction part into evaluation mode, but I am wondering if there is a cleaner way, since I am combining the feature extractor and the dense classification layers in one model…

During inference, batch norm will be frozen. However, during training, it will be updated. To resolve this, you need to explicitly freeze batch norm during training. The best way to do that is to override the train() method in your nn.Module (i.e. your model definition) so that it freezes batch norm whenever the model is put into training mode (otherwise any later call to model.train() would switch the batch norm layers back on). Here is an example:

class DenseNetConv(torch.nn.Module):
    def __init__(self):
        super(DenseNetConv, self).__init__()
        original_model = models.densenet161(pretrained=False)
        self.features = torch.nn.Sequential(*list(original_model.children())[:-1])
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        for param in self.parameters():
            param.requires_grad = False

    def train(self, mode=True):
        """
        Override the default train() to freeze the BN parameters.
        """
        super(DenseNetConv, self).train(mode)

        print("Freezing Mean/Var of BatchNorm2D.")
        print("Freezing Weight/Bias of BatchNorm2D.")

        for m in self.features.modules():
            if isinstance(m, nn.BatchNorm2d):
                # Keep BN in eval mode so the running stats are not updated,
                # and freeze its affine (weight/bias) parameters.
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False

    def forward(self, x):
        x = self.features(x)
        x = F.relu(x, inplace=True)
        x = F.avg_pool2d(x, kernel_size=7).view(x.size(0), -1)
        return x
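
As a quick sanity check (a minimal sketch, assuming the classes above are defined), calling train() on the combined model should now leave every BatchNorm2d layer in eval mode:

densenet = MyDenseNet()
densenet.train()  # also calls DenseNetConv.train(), which re-freezes the BN layers

bn_training = [m.training for m in densenet.mrnc.features.modules()
               if isinstance(m, nn.BatchNorm2d)]
print(all(not t for t in bn_training))  # expected: True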

Please refer to this thread for more information.

Ah ok, so during inference it will be frozen automatically. As I am not using ImageNet, what running stats would then be applied at inference?

They will be whatever was accumulated during training; in your case, your training dataset's stats.
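
If it helps, the running statistics that a frozen BN layer will apply at inference can be inspected directly; a minimal sketch, assuming densenet is the combined model from the question:

for m in densenet.modules():
    if isinstance(m, nn.BatchNorm2d):
        # running_mean / running_var are buffers that only change while BN is in training mode
        print(m.running_mean[:5], m.running_var[:5])
        break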

Ah ok, and so if the batch norm layers are switched off during training, my running stats will essentially just stay at their initial values…

Would this be appropriate too?

model = models.densenet161(pretrained=False)
for param in model.parameters():
    param.requires_grad = False
num_ftrs = model.classifier.in_features

# Put the BN layers in eval mode and freeze their affine parameters.
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eval()
        m.weight.requires_grad = False
        m.bias.requires_grad = False

model.classifier = torch.nn.Linear(num_ftrs, 2)
print(model.classifier)
model.to(device)
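
One caveat, sketched here as an assumption that follows from the train() override shown earlier: any later call to model.train() (for example at the start of each epoch) will put the BatchNorm2d layers back into training mode, so the eval() calls would need to be re-applied each time, e.g. with a small helper:

def freeze_bn(model):
    # Put all BatchNorm2d layers back into eval mode; call after every model.train().
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()

model.train()     # resets the BN layers to training mode
freeze_bn(model)  # freeze them again before running the training loop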

Why are you freezing your backbone, especially if it isn’t pre-trained?

One more thing: it is better to use a slightly deeper MLP (three layers) with ReLU activations in between:

densenet161 -> flatten -> linear (relu) [num_ftrs, 256] -> linear (relu) [256, 256] -> linear [256, 2]
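
A minimal sketch of that suggested head, assuming num_ftrs = model.classifier.in_features as in the snippet above (the 256-unit widths are just the sizes from the diagram, and torchvision's DenseNet already flattens the pooled features before the classifier):

model.classifier = nn.Sequential(
    nn.Linear(num_ftrs, 256),
    nn.ReLU(inplace=True),
    nn.Linear(256, 256),
    nn.ReLU(inplace=True),
    nn.Linear(256, 2),
)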

Don’t freeze your backbone and batch norm if your model wasn’t trained, because you will end up with an ineffective feature extractor, and the MLP won’t have a good representation to learn from for your task.

The experiment I am attempting investigates how effective the convolutional layers are when they are randomly initialised. Previous research appears to suggest that convnets can be good classifiers even when the convolutional layers are not trained.

I basically want to see how important batch normalisation is with regard to my classification task.

And yes, I agree that a deeper MLP would be beneficial, but I am following the vanilla example provided by PyTorch’s tutorial on transfer learning:

https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
