How do I stop updates to a part of BN's running_mean and running_var buffers?

I am training a model in which only a part of the BN’s weights and biases is updated. I can do this by setting the gradients of the ‘frozen’ part to zero during the backward pass. However, the running_mean and running_var buffers still get updated with every iteration. How do I stop this from happening?

Thanks.

Call .eval() on the batchnorm layer(s) to disable the updates of the running stats in each forward pass.
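
For example, assuming the model is held in a variable called model, something like this keeps the rest of the network in train mode while freezing the BN statistics:

import torch.nn as nn

model.train()  # everything else trains normally
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eval()  # stops running_mean/running_var updates; weight/bias still receive gradients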

Pardon my lack of clarity in the original post. The situation I’m in is something like this. I have a network that contains a BN layer, say:

bn = nn.BatchNorm2d(64)

So the layer above has 64 features:

bn.weight.shape        # torch.Size([64])
bn.running_mean.shape  # torch.Size([64])

Now, while training, I want the first 8 elements of the bn layer to stay fixed and only the remaining 56 to be updated. I can easily achieve that for bn.weight by setting the gradients of its first 8 elements to zero in the backward pass, using register_hook on bn.weight.
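
For reference, a minimal sketch of that gradient-masking step (the hook returns a modified copy instead of mutating the gradient in place):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(64)

def mask_first_8(grad):
    grad = grad.clone()  # don't modify the gradient autograd hands us in place
    grad[:8] = 0         # zero the gradient for the frozen elements
    return grad

bn.weight.register_hook(mask_first_8)
bn.bias.register_hook(mask_first_8)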

But I can’t stop bn.running_mean’s first 8 elements from being updated. Every time there is a forward call on the bn layer, all of bn.running_mean’s elements get updated.

What I want is for bn.running_mean to keep its first 8 elements fixed and update the remaining 56 as usual. How do I achieve this?

Thanks.

I am currently using the following layer implementation, which supports partial updates to running_mean and running_var through the partial_features argument:

import torch
import torch.nn as nn


class PartialBatchNorm2d(nn.BatchNorm2d):
    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        partial_features=0,
    ):
        super().__init__(
            num_features=num_features,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
        )
        # Registered as a buffer so it is saved and restored with the state_dict.
        self.register_buffer('partial_features', torch.zeros(1, dtype=torch.int32))
        self.partial_features[0] = partial_features

    def forward(self, x):
        # Running stats are only updated in training mode, so there is
        # nothing to protect otherwise (and running_mean may be None when
        # track_running_stats=False).
        if not (self.training and self.track_running_stats):
            return super().forward(x)

        # Snapshot the running stats before the parent forward updates them.
        running_mean_copy = self.running_mean.clone()
        running_var_copy = self.running_var.clone()

        x = super().forward(x)

        # Restore the frozen slice so only elements [n:] keep the new stats.
        n = int(self.partial_features)
        self.running_mean[:n] = running_mean_copy[:n]
        self.running_var[:n] = running_var_copy[:n]

        return x
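
A quick sanity check of the behaviour, assuming partial_features=8: after a training-mode forward pass, the first 8 running-stat entries should stay at their initial values while the rest update.

bn = PartialBatchNorm2d(64, partial_features=8)
bn.train()
mean_before = bn.running_mean.clone()

_ = bn(torch.randn(4, 64, 8, 8))

print(torch.equal(bn.running_mean[:8], mean_before[:8]))  # True: frozen slice untouched
print(torch.equal(bn.running_mean[8:], mean_before[8:]))  # False: the rest updated as usual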