During inference, batch norm layers are frozen. During training, however, their running mean and variance are still updated, even if their parameters have requires_grad = False. To resolve this, you need to explicitly freeze batch norm during training. The cleanest way to do that is to override the train() method in your nn.Module (i.e., your model definition) so that it puts the batch norm layers back into eval mode. Here is an example:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


class DenseNetConv(torch.nn.Module):
    def __init__(self):
        super(DenseNetConv, self).__init__()
        original_model = models.densenet161(pretrained=False)
        # Keep everything except the final classifier layer
        self.features = torch.nn.Sequential(*list(original_model.children())[:-1])
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # Freeze all parameters so the optimizer does not update them
        for param in self.parameters():
            param.requires_grad = False

    def train(self, mode=True):
        """
        Override the default train() to freeze the BN parameters
        """
        super(DenseNetConv, self).train(mode)
        print("Freezing Mean/Var of BatchNorm2D.")
        print("Freezing Weight/Bias of BatchNorm2D.")
        for m in self.features.modules():
            if isinstance(m, nn.BatchNorm2d):
                # eval() stops the running mean/var from being updated
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False

    def forward(self, x):
        x = self.features(x)
        x = F.relu(x, inplace=True)
        x = F.avg_pool2d(x, kernel_size=7).view(x.size(0), -1)
        return x
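As a quick sanity check, you can instantiate the model, call train(), and confirm that the batch norm layers remain in eval mode. This is just a minimal sketch, assuming the DenseNetConv class above is in scope:

model = DenseNetConv()
model.train()  # triggers the overridden train(), which calls m.eval() on every BN layer

for m in model.features.modules():
    if isinstance(m, nn.BatchNorm2d):
        # m.training is False, so the running mean/var will not be updated,
        # and weight/bias have requires_grad == False
        print(m.training, m.weight.requires_grad, m.bias.requires_grad)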
Please refer to this thread for more information.