How to disable BatchNorm when using torchvision models

I am trying to fine-tune the pretrained ResNet model provided by torchvision, and I need to remove several batchnorm layers before fine-tuning. How could I disable these BN layers?


This code should deactivate the nn.BatchNorm2d layers:

import torch
import torch.nn as nn

def deactivate_batchnorm(m):
    if isinstance(m, nn.BatchNorm2d):
        # reset running_mean to 0 and running_var to 1
        m.reset_parameters()
        # eval mode: use the (reset) running estimates, not batch statistics
        m.eval()
        with torch.no_grad():
            m.weight.fill_(1.0)  # gamma = 1
            m.bias.zero_()       # beta = 0


model = nn.Sequential(
    nn.Conv2d(3, 6, 3, 1, 1),
    nn.BatchNorm2d(6)
)

x = torch.randn(10, 3, 24, 24)
output1 = model[0](x)  # conv output only
output2 = model(x)     # conv + batchnorm
print(torch.allclose(output1, output2))
> False

model.apply(deactivate_batchnorm)
output1 = model[0](x)
output2 = model(x)
print(torch.allclose(output1, output2))
> True

Keep in mind that you would need to call this function again after switching the whole model between model.train() and model.eval().
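
For example, a minimal sketch continuing the snippet above: model.train() puts the batchnorm layers back into training mode, so re-apply the function afterwards:

model.train()                      # sets all submodules, including BN, back to training mode
model.apply(deactivate_batchnorm)  # push the BN layers back into the deactivated state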

Another approach would be to completely remove the batch norm layers and recreate the model, which might be quite complicated depending on your model and its forward method.
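
As a rough sketch (the helper remove_batchnorm below is hypothetical, not a torchvision API), you could swap every nn.BatchNorm2d for nn.Identity; note that this skips the normalization entirely, so the pretrained weights will see different activation statistics:

import torch.nn as nn
import torchvision.models as models

def remove_batchnorm(module):
    # recursively replace every BatchNorm2d child with a no-op layer
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, nn.Identity())
        else:
            remove_batchnorm(child)

model = models.resnet18(pretrained=True)
remove_batchnorm(model)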


Thanks a lot. But would setting β = 0 and γ = 1 alone disable the effect of batchnorm? The input activations would still be normalized with the batch's own mean and variance.

I’ve also reset the running estimates and set the layer to eval, so that these (reset) running estimates are used and no batch statistics are calculated. With running_mean = 0, running_var = 1, γ = 1, and β = 0, the layer then computes x / sqrt(1 + eps) in eval mode, which is approximately the identity.
It’s not pretty, since the layer will silently fall back to batch statistics if you forget to switch it back to eval, so that’s why I also suggested the other approach of removing the layers completely.
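
A quick sketch to verify this numerically; the atol absorbs the small eps term in BN's denominator:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(6)
bn.reset_parameters()  # running_mean = 0, running_var = 1, weight = 1, bias = 0
bn.eval()              # use the running estimates

x = torch.randn(10, 6, 24, 24)
# in eval mode bn(x) = x / sqrt(1 + eps), i.e. almost exactly x
print(torch.allclose(bn(x), x, atol=1e-4))
> True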

Ah I see. Thanks a lot.