I am trying to fine-tune the pretrained ResNet model provided by torchvision, and I need to disable several batchnorm layers before fine-tuning. How can I turn these BN layers off?
This code should deactivate the batch norm layers:
```python
import torch
import torch.nn as nn

def deactivate_batchnorm(m):
    if isinstance(m, nn.BatchNorm2d):
        m.reset_parameters()
        m.eval()
        with torch.no_grad():
            m.weight.fill_(1.0)
            m.bias.zero_()

model = nn.Sequential(
    nn.Conv2d(3, 6, 3, 1, 1),
    nn.BatchNorm2d(6)
)
x = torch.randn(10, 3, 24, 24)
output1 = model(x)
output2 = model(x)
print(torch.allclose(output1, output2))
# > False

model.apply(deactivate_batchnorm)
output1 = model(x)
output2 = model(x)
print(torch.allclose(output1, output2))
# > True
```
Keep in mind that you would need to call this function again after switching the whole model between training and evaluation, since `model.train()` will put the batch norm layers back into training mode.
Another approach would be to completely remove the batch norm layers and recreate the model, which might be quite complicated depending on your model.
Thanks a lot. But would setting β = 0 and γ = 1 really disable the effect of batchnorm? The input activations would still be normalized with their own mean and variance within the batch.
I’ve also reset the running estimates and set the layer to `eval`, so that these running estimates will be used and no batch statistics will be calculated.
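To make that concrete, here is a small check (the variable names are mine): after `reset_parameters()` the running mean is 0 and the running variance is 1, and with γ = 1, β = 0 the layer in `eval` mode reduces to dividing by `sqrt(1 + eps)`, i.e. it is essentially the identity.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(6)
bn.reset_parameters()   # running_mean = 0, running_var = 1
bn.eval()               # use running estimates; don't compute batch statistics
with torch.no_grad():
    bn.weight.fill_(1.0)  # gamma
    bn.bias.zero_()       # beta

x = torch.randn(10, 6, 24, 24)
out = bn(x)
# y = (x - 0) / sqrt(1 + eps) * 1 + 0, so the output is (almost) the input
print(torch.allclose(out, x / (1 + bn.eps) ** 0.5))
```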
It’s not pretty, as you might get errors if you forget to switch this layer back to `eval`, which is why I also suggested the other approach of removing the layers completely.
Ah, I see. Thanks a lot.