Need help: can I modify a BatchNorm2d's num_features?

How can I modify just the num_features of a BatchNorm2d?
For example, how can I create a new BatchNorm2d object with a new num_features but all other properties taken from the old one?
Thank you!

Are you looking for something like this:

import torch.nn as nn

bn1 = nn.BatchNorm2d(num_features=3)
# New layer with a different num_features, reusing bn1's hyperparameters
bn_new = nn.BatchNorm2d(
    num_features=4,
    eps=bn1.eps,
    momentum=bn1.momentum,
    affine=bn1.affine,
    track_running_stats=bn1.track_running_stats)
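A quick check of what this does and does not carry over, as a minimal sketch (the eps and momentum values given to bn1 here are arbitrary illustrative choices): the hyperparameters are copied, but weight, bias, running_mean and running_var are created fresh with the new size (ones, zeros, zeros and ones respectively). Device and dtype are not copied either.

```python
import torch
import torch.nn as nn

# Illustrative non-default hyperparameters so the copy is visible
bn1 = nn.BatchNorm2d(num_features=3, eps=1e-3, momentum=0.2)
bn_new = nn.BatchNorm2d(
    num_features=4,
    eps=bn1.eps,
    momentum=bn1.momentum,
    affine=bn1.affine,
    track_running_stats=bn1.track_running_stats)

# Hyperparameters match bn1...
print(bn_new.eps, bn_new.momentum, bn_new.affine, bn_new.track_running_stats)
# ...but the per-channel tensors are newly initialized with 4 entries
print(bn_new.weight.shape, bn_new.running_mean.shape)
```

So this answers the "other properties" part, but none of bn1's learned values end up in bn_new.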

Yes! Will this keep all other parameters unchanged (except num_features)?

What other parameters do you mean?
Changing num_features also changes the size of the weight, bias, and running stats.
Could you explain your use case? Maybe I’m misunderstanding it.

Say the BatchNorm layer follows a Conv2d layer with out_channels=64, so the BatchNorm layer is nn.BatchNorm2d(num_features=64).
If the filter at index 10 of the previous layer is pruned, the conv layer outputs one feature map fewer.
The BatchNorm layer should then become nn.BatchNorm2d(num_features=63).
The weight, bias, and running stats of the BatchNorm layer are still useful. How can I keep these values in the new BatchNorm layer? Something like new_param[0:10] = old_param[0:10]; new_param[10:] = old_param[11:].
Thank you!
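The slicing described above is equivalent to concatenating the two kept slices, or to indexing with a boolean keep-mask. A minimal sketch, using a random 64-element tensor as a stand-in for one of the per-channel BatchNorm tensors; the mask form has the advantage of extending to dropping several channels at once.

```python
import torch

old_param = torch.randn(64)  # stand-in for e.g. bn.weight of a 64-channel layer

# Literal version of the slicing from the question
new_param = torch.empty(63)
new_param[0:10] = old_param[0:10]
new_param[10:] = old_param[11:]

# Same result by concatenating the two kept slices
via_cat = torch.cat((old_param[:10], old_param[11:]))

# Same result with a boolean keep-mask (can drop several channels at once)
keep = torch.ones(64, dtype=torch.bool)
keep[10] = False
via_mask = old_param[keep]

print(torch.equal(new_param, via_cat), torch.equal(new_param, via_mask))
```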

Hi, if I want to keep the weight, bias, and running stats of the old BatchNorm layer, what should I do?
The details are in my reply above.

This should generally work.
Here is a small code sample showing that the model outputs the right shape after pruning:

import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(64)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        return x


model = MyModel()
x = torch.randn(1, 1, 10, 10)
output = model(x)
output.sum().backward()
print(model.conv1.weight.grad.shape)  # torch.Size([64, 1, 3, 3])
model.zero_grad()

# Prune output channel 10 of conv1 and the matching channel of bn1
with torch.no_grad():
    model.conv1.weight = nn.Parameter(torch.cat((
        model.conv1.weight[:10], model.conv1.weight[11:])))
    model.conv1.bias = nn.Parameter(torch.cat((
        model.conv1.bias[:10], model.conv1.bias[11:])))
    model.bn1.weight = nn.Parameter(torch.cat((
        model.bn1.weight[:10], model.bn1.weight[11:])))
    model.bn1.bias = nn.Parameter(torch.cat((
        model.bn1.bias[:10], model.bn1.bias[11:])))
    # running_mean/running_var are buffers, so plain tensors are assigned
    model.bn1.running_mean = torch.cat((
        model.bn1.running_mean[:10], model.bn1.running_mean[11:]))
    model.bn1.running_var = torch.cat((
        model.bn1.running_var[:10], model.bn1.running_var[11:]))
    # Keep the bookkeeping attributes consistent with the new shapes
    model.conv1.out_channels = 63
    model.bn1.num_features = 63
# These are new Parameter objects: rebuild any optimizer created before pruning

output = model(x)
print(output.shape)  # torch.Size([1, 63, 10, 10])
output.sum().backward()
print(model.conv1.weight.grad.shape)  # torch.Size([63, 1, 3, 3])
model.zero_grad()
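If you would rather create a brand-new BatchNorm2d object (as asked originally) instead of patching the existing one in place, the two answers can be combined: build the layer with copied hyperparameters, then copy in the kept entries. A minimal sketch; prune_bn2d is a hypothetical helper name, and it assumes the layer lives on the CPU with default dtype (move the result with .to(...) otherwise). The conv layer still has to be pruned separately, as in the reply above.

```python
from typing import Sequence

import torch
import torch.nn as nn


def prune_bn2d(bn: nn.BatchNorm2d, drop: Sequence[int]) -> nn.BatchNorm2d:
    """Hypothetical helper: new BatchNorm2d without the channels in `drop`."""
    keep = torch.ones(bn.num_features, dtype=torch.bool)
    keep[torch.as_tensor(drop, dtype=torch.long)] = False
    # Same hyperparameters as `bn`, fewer channels
    new_bn = nn.BatchNorm2d(
        num_features=int(keep.sum()),
        eps=bn.eps,
        momentum=bn.momentum,
        affine=bn.affine,
        track_running_stats=bn.track_running_stats)
    with torch.no_grad():
        if bn.affine:
            new_bn.weight.copy_(bn.weight[keep])
            new_bn.bias.copy_(bn.bias[keep])
        if bn.track_running_stats:
            new_bn.running_mean.copy_(bn.running_mean[keep])
            new_bn.running_var.copy_(bn.running_var[keep])
            new_bn.num_batches_tracked.copy_(bn.num_batches_tracked)
    return new_bn


bn = nn.BatchNorm2d(64)
bn(torch.randn(8, 64, 5, 5))  # one training-mode pass to populate running stats
pruned = prune_bn2d(bn, [10])
print(pruned.num_features)
```

In the model above this would be used as model.bn1 = prune_bn2d(model.bn1, [10]). Because the result is a fresh module, num_features and repr() are correct automatically. As with the in-place version, any existing optimizer has to be rebuilt.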