Freezing Batch norm layers while keeping them in the graph

If we set requires_grad to False for the batch norm layers of a model, the batch norm layers do not remain in the graph. In this case, I can't fine-tune these layers later if I want to. Is there any way to freeze the layers yet keep them in the graph so that they can be trained later?

You could set the requires_grad attributes of the batchnorm parameters to True again if you want to train them later.
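As a minimal sketch of that toggle: the helper name `set_batchnorm_requires_grad` and the toy model are made up for illustration, and only the `requires_grad_` call comes from PyTorch itself.

```python
import torch.nn as nn

# Batchnorm layer types to match; SyncBatchNorm is included for multi-GPU setups.
BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)


def set_batchnorm_requires_grad(model: nn.Module, requires_grad: bool) -> int:
    """Hypothetical helper: toggle requires_grad on the affine parameters
    (weight and bias) of every batchnorm layer and return how many were touched."""
    touched = 0
    for module in model.modules():
        if isinstance(module, BN_TYPES):
            for param in module.parameters():
                param.requires_grad_(requires_grad)
                touched += 1
    return touched


# Toy model, assumed for the example only.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

# Initial training: batchnorm affine parameters are frozen.
num_touched = set_batchnorm_requires_grad(model, False)
frozen = [p.requires_grad for p in model[1].parameters()]

# Later fine-tuning: make them trainable again.
set_batchnorm_requires_grad(model, True)
unfrozen = [p.requires_grad for p in model[1].parameters()]
```

If the optimizer was created from a filtered parameter list while the layers were frozen, the batchnorm parameters would also have to be added to it (e.g. via `optimizer.add_param_group`) before fine-tuning.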

Thank you very much. I tried this. Setting requires_grad = False during the initial training keeps the batch norm layers from training, and they can later be fine-tuned by setting requires_grad = True.
However, I noticed that even when the batch norm layers are not being trained (requires_grad = False), I think all the BN gradient calculations are still happening and a lot of memory is being occupied. In fact, I can't train the complete model because I very often get an out-of-memory error. I am wondering if there is a way to shut down the mean and variance calculations.
BTW, I was using the SyncBatchNorm class because I am using 2 GPUs for training with DataParallel.

The gradients with respect to the batchnorm inputs might still be needed, since previous layers depend on them during the backward pass.
To disable the stats updates, call .eval() on the batchnorm layers or use track_running_stats=False if you want to use the batch stats during training and evaluation.
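Both options could look roughly like this. The helper name `freeze_batchnorm_stats` and the small model are assumptions for the example, not part of PyTorch.

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)


def freeze_batchnorm_stats(model: nn.Module) -> None:
    """Hypothetical helper: put only the batchnorm layers into eval mode so
    their running stats are used for normalization and no longer updated."""
    for module in model.modules():
        if isinstance(module, BN_TYPES):
            module.eval()


# Option 1: call .eval() on the batchnorm layers only.
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))
model.train()                  # puts every submodule into training mode
freeze_batchnorm_stats(model)  # so the helper has to be applied afterwards

before = model[1].running_mean.clone()
model(torch.randn(16, 4))
stats_unchanged = torch.equal(before, model[1].running_mean)

# Option 2: never track running stats; batch stats are then used in both
# training and evaluation, and the running buffers are None.
bn_batch_only = nn.BatchNorm1d(8, track_running_stats=False)
has_no_running_mean = bn_batch_only.running_mean is None
```

Note that `model.train()` switches every submodule back to training mode, so the helper would need to be called again after each such call.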

Thanks. I will try this ASAP. Meanwhile, I was wondering if there is a way to set the batch stats (mean and variance) to 0s and 1s so that all gradients are based on these values and we avoid calculating the means and variances altogether (which might be taking up memory)?
Also, I was following your manual implementation of the batch normalization layer (pytorch_misc/ at master · ptrblck/pytorch_misc · GitHub). Could you tell me if it supports multiple GPUs?

The running stats will be initialized with these values (zeros for the mean, ones for the variance), so you could call .eval() directly on the batchnorm layers after initializing the model.
However, note that freezing the affine parameters and disabling the normalization via .eval() would mean that you are basically adding “noise” to the model and I’m unsure if that’s your use case or why you would want to do it.
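A small check of the initial values, as a sketch only: it assumes a freshly constructed `nn.BatchNorm2d` with default arguments, and the tensor shapes are arbitrary.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(4)

# A new layer starts with running_mean = 0 and running_var = 1.
mean_is_zero = torch.equal(bn.running_mean, torch.zeros(4))
var_is_one = torch.equal(bn.running_var, torch.ones(4))

# Freeze the affine parameters and switch to eval mode right after init.
for param in bn.parameters():
    param.requires_grad_(False)
bn.eval()

x = torch.randn(2, 4, 5, 5)
y = bn(x)

# In eval mode the layer computes
#   (x - running_mean) / sqrt(running_var + eps) * weight + bias.
# With the default init (weight = 1, bias = 0) this reduces to x / sqrt(1 + eps).
expected = x / (1.0 + bn.eps) ** 0.5
matches_formula = torch.allclose(y, expected, atol=1e-6)

# The forward pass in eval mode leaves the running stats untouched.
stats_still_initial = torch.equal(bn.running_mean, torch.zeros(4))
```

This only covers a layer that has never been trained; a pretrained layer would apply whatever running stats and affine parameters it was loaded with.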

It might work with e.g. DDP, but note that I’ve written it more for educational reasons to play around with the manual implementation in Python directly.