I have a DenseNet-based model that extracts features from an untrained (randomly initialised) backbone and feeds the resulting tensors into a fully connected classifier, as shown below:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class DenseNetConv(torch.nn.Module):
    def __init__(self):
        super(DenseNetConv, self).__init__()
        original_model = models.densenet161(pretrained=False)
        self.features = torch.nn.Sequential(*list(original_model.children())[:-1])
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        x = self.features(x)
        x = F.relu(x, inplace=True)
        x = F.avg_pool2d(x, kernel_size=7).view(x.size(0), -1)
        return x
class MyDenseNetDens(torch.nn.Module):
    def __init__(self, nb_out=2):
        super().__init__()
        self.dens1 = torch.nn.Linear(in_features=2208, out_features=512)
        self.dens2 = torch.nn.Linear(in_features=512, out_features=128)
        self.dens3 = torch.nn.Linear(in_features=128, out_features=nb_out)

    def forward(self, x):
        x = self.dens1(x)
        x = torch.nn.functional.selu(x)
        x = F.dropout(x, p=0.25, training=self.training)
        x = self.dens2(x)
        x = torch.nn.functional.selu(x)
        x = F.dropout(x, p=0.25, training=self.training)
        x = self.dens3(x)
        return x
class MyDenseNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mrnc = DenseNetConv()  # the feature extractor defined above
        self.mrnd = MyDenseNetDens()

    def forward(self, x):
        x = self.mrnc(x)
        x = self.mrnd(x)
        return x
densenet = MyDenseNet()
densenet.to(device)
densenet.train()
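For reference, the sizes line up for 224x224 RGB inputs (an assumption about my input size): the backbone pools down to a 2208-dimensional feature vector, which is where in_features=2208 in the first linear layer comes from. A quick shape check:

dummy = torch.randn(4, 3, 224, 224, device=device)  # batch of four 224x224 RGB images (assumed input size)
features = densenet.mrnc(dummy)                      # torch.Size([4, 2208]); matches in_features of dens1
logits = densenet(dummy)                             # torch.Size([4, 2])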
However, what is the best way of switching off the batch norm layers in this model for both training and inference? I would typically put the feature extraction part into evaluation mode, but I am wondering whether there is a cleaner way, since I am combining the feature extractor and the dense classification layers in one model.
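For reference, my current workaround (a minimal sketch of what I mean by putting the feature extraction part into evaluation mode) looks like this:

densenet.train()      # puts every submodule, including the frozen backbone, into training mode
densenet.mrnc.eval()  # manually switch the feature extractor (and its BatchNorm layers) back to eval mode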
During inference (eval mode), batch norm is effectively frozen. During training, however, its running statistics are still updated, even when requires_grad is False on its parameters. To resolve this, you need to explicitly freeze batch norm during training. The best way to do that is to override the train() method in your nn.Module (i.e. your model definition) so that it freezes batch norm whenever the model is switched to training mode. Here is an example:
class DenseNetConv(torch.nn.Module):
    def __init__(self):
        super(DenseNetConv, self).__init__()
        original_model = models.densenet161(pretrained=False)
        self.features = torch.nn.Sequential(*list(original_model.children())[:-1])
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        for param in self.parameters():
            param.requires_grad = False

    def train(self, mode=True):
        """
        Override the default train() to freeze the BN parameters.
        """
        super(DenseNetConv, self).train(mode)
        print("Freezing Mean/Var of BatchNorm2D.")
        print("Freezing Weight/Bias of BatchNorm2D.")
        for m in self.features.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()                        # keep running mean/var fixed
                m.weight.requires_grad = False  # freeze the affine weight
                m.bias.requires_grad = False    # freeze the affine bias

    def forward(self, x):
        x = self.features(x)
        x = F.relu(x, inplace=True)
        x = F.avg_pool2d(x, kernel_size=7).view(x.size(0), -1)
        return x
# Alternative: freeze the whole stock densenet161 (including its BatchNorm layers)
# and only replace the classifier head.
model = models.densenet161(pretrained=False)
for param in model.parameters():
    param.requires_grad = False

num_ftrs = model.classifier.in_features

for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eval()  # note: a later call to model.train() undoes this unless train() is overridden as above
        m.weight.requires_grad = False
        m.bias.requires_grad = False

model.classifier = torch.nn.Linear(num_ftrs, 2)
print(model.classifier)
model.to(device)
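As a quick sanity check, here is a minimal sketch (using the DenseNetConv subclass with the overridden train() above and the MyDenseNetDens head from your question; the optimizer choice and its hyperparameters are just illustrative) that confirms the BatchNorm layers stay in eval mode after calling train(), and that only trainable parameters are handed to the optimizer:

backbone = DenseNetConv()
backbone.train()  # prints the freezing messages; BN modules are pushed back into eval mode

# The module itself is in training mode, but every BatchNorm2d inside it should not be
print(backbone.training)  # True
print(all(not m.training for m in backbone.modules() if isinstance(m, nn.BatchNorm2d)))  # True

# Only pass parameters that still require gradients to the optimizer (illustrative hyperparameters)
head = MyDenseNetDens(nb_out=2)
trainable = [p for p in list(backbone.parameters()) + list(head.parameters()) if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.001, momentum=0.9)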
Why are you freezing your backbone, especially when it isn’t pre-trained?
One more thing: it is better to use a slightly deeper MLP head (3 layers) with ReLU between the layers, for example:
densenet161 -> flatten -> linear (relu) [num_ftrs, 256] -> linear (relu) [256, 256] -> linear [256, 2]
Don’t freeze the backbone and its batch norm layers if the backbone was never trained: you end up with a poor feature extractor, and the MLP has no useful representation to learn from for your task.
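In code, that suggestion would look roughly like the sketch below (the 256-unit widths are just the ones from the diagram; num_ftrs comes from the densenet161 classifier, and torchvision’s DenseNet already flattens the pooled features before the classifier):

model = models.densenet161(pretrained=False)
num_ftrs = model.classifier.in_features  # 2208 for densenet161

# Suggested head: flatten (done inside DenseNet.forward) -> linear+relu -> linear+relu -> linear
model.classifier = nn.Sequential(
    nn.Linear(num_ftrs, 256),
    nn.ReLU(inplace=True),
    nn.Linear(256, 256),
    nn.ReLU(inplace=True),
    nn.Linear(256, 2),
)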
The experiment I am attempting is to investigate how effective the convolutional layers are when they are randomly initialised. Previous research appears to suggest that convnets can be reasonable classifiers even when the convolutional layers are not trained.
I basically want to see how important batch normalisation is for my classification task.
And yes, I agree that a deeper MLP would be beneficial, but I am following the vanilla example provided by PyTorch’s tutorial on transfer learning.