I am training a model in which only part of a BatchNorm layer’s weights and biases are updated. I can do this by setting the gradients of the ‘frozen’ part to zero during the backward pass. However, the running_mean and running_var buffers still get updated on every forward pass. How do I stop this from happening?
Pardon my lack of clarity in the original post. The situation I’m in is something like this. I have a network that contains a BN layer say:
bn = nn.BatchNorm2d(64)
So the layer above has 64 features:

bn.weight.shape        # torch.Size([64])
bn.running_mean.shape  # torch.Size([64])
Now, while training, I want the first 8 channels of the bn layer to stay frozen and the remaining 56 to be updated. I can easily achieve that for bn.weight by registering a hook on the parameter (Tensor.register_hook) that zeroes the gradients of its first 8 elements during the backward pass, as in the snippet below.
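For the record, the gradient-freezing part looks something like this (a minimal sketch; the mask is just one way to write the hook, and 8 is the number of frozen channels):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(64)

# Zero the gradients of the first 8 channels during the backward pass,
# so those entries of bn.weight and bn.bias never receive updates.
mask = torch.ones_like(bn.weight)
mask[:8] = 0
bn.weight.register_hook(lambda grad: grad * mask)
bn.bias.register_hook(lambda grad: grad * mask)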
But I can’t stop the first 8 elements of bn.running_mean from being updated. Every time there is a forward call on the bn layer in training mode, all of bn.running_mean’s elements get updated.
What I want is for bn.running_mean to keep its first 8 elements fixed and have the remaining 56 elements update as usual. How do I achieve this?
I am currently using the following layer implementation, which supports partial updates to running_mean and running_var through the partial_features variable.
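Roughly, the idea is to let the parent BatchNorm2d do its usual in-place update during forward and then restore the frozen slice of the stats afterwards. A minimal sketch along those lines (PartialBatchNorm2d and the exact details are my reconstruction, not necessarily the code from the post; partial_features counts the leading frozen channels):

import torch
import torch.nn as nn

class PartialBatchNorm2d(nn.BatchNorm2d):
    """BatchNorm2d whose first `partial_features` channels keep frozen
    running statistics and frozen affine parameters."""

    def __init__(self, num_features, partial_features, **kwargs):
        super().__init__(num_features, **kwargs)
        self.partial_features = partial_features
        if self.affine:
            # Zero the gradients of the frozen slice during backward.
            self.weight.register_hook(self._zero_frozen_grad)
            self.bias.register_hook(self._zero_frozen_grad)

    def _zero_frozen_grad(self, grad):
        grad = grad.clone()
        grad[: self.partial_features] = 0
        return grad

    def forward(self, x):
        if self.training and self.track_running_stats:
            k = self.partial_features
            # Snapshot the frozen slice of the running stats ...
            frozen_mean = self.running_mean[:k].clone()
            frozen_var = self.running_var[:k].clone()
            out = super().forward(x)
            # ... and restore it, undoing the momentum update for
            # the first k channels only.
            with torch.no_grad():
                self.running_mean[:k] = frozen_mean
                self.running_var[:k] = frozen_var
            return out
        return super().forward(x)

Used as bn = PartialBatchNorm2d(64, partial_features=8), this leaves the first 8 entries of running_mean and running_var untouched while the other 56 update as usual. One caveat: num_batches_tracked still increments on every forward call, which only matters if momentum=None (cumulative averaging).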