How can I modify just the num_features of a BatchNorm2d layer?
For example, how can I create a new BatchNorm2d object with a new num_features but the same other properties as the old one?
Thank you!
Are you looking for something like this:
```python
import torch.nn as nn

bn1 = nn.BatchNorm2d(num_features=3)
bn_new = nn.BatchNorm2d(
    num_features=4,
    eps=bn1.eps,
    momentum=bn1.momentum,
    affine=bn1.affine,
    track_running_stats=bn1.track_running_stats)
```
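If this has to be done for several layers, the constructor call above can be wrapped in a small function. This is a minimal sketch: `resize_batchnorm` is a hypothetical helper name, and passing `device`/`dtype` assumes a PyTorch version whose modules accept those factory arguments (1.9 or later). Only the hyperparameters are copied; the learnable parameters and running stats are freshly initialized.

```python
import torch
import torch.nn as nn

def resize_batchnorm(old: nn.BatchNorm2d, num_features: int) -> nn.BatchNorm2d:
    """Hypothetical helper: a new BatchNorm2d with a different num_features
    that reuses every other constructor argument of `old`.

    weight, bias and running stats are NOT copied; they get default values.
    """
    # Place the new layer on the same device/dtype as the old one, if the
    # old layer has any tensor to look at (affine=False and
    # track_running_stats=False leaves it with none).
    ref = old.weight if old.weight is not None else old.running_mean
    extra = {} if ref is None else {"device": ref.device, "dtype": ref.dtype}
    return nn.BatchNorm2d(
        num_features,
        eps=old.eps,
        momentum=old.momentum,
        affine=old.affine,
        track_running_stats=old.track_running_stats,
        **extra,
    )

bn1 = nn.BatchNorm2d(num_features=3, eps=1e-3, momentum=0.05)
bn_new = resize_batchnorm(bn1, 4)
print(bn_new)
```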
Yes! Will this keep all the other parameters unchanged (everything except num_features)?
What other parameters do you mean?
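Changing num_features will also change weight, bias, and the running stats: the new layer creates them with the new shape and default values, so nothing is carried over from the old layer.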
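To illustrate what "change" means here, the sketch below compares a "trained" layer with a freshly constructed one. The training is only simulated (made-up weight/bias values and a single random batch to update the running stats); the point is that the new layer starts from PyTorch's default initialization, not from the old layer's values.

```python
import torch
import torch.nn as nn

bn_old = nn.BatchNorm2d(3)
# Simulate a "trained" layer: non-default affine parameters...
with torch.no_grad():
    bn_old.weight.fill_(2.0)
    bn_old.bias.fill_(0.5)
# ...and one training-mode forward pass, which updates the running stats.
bn_old(torch.randn(8, 3, 4, 4))

bn_new = nn.BatchNorm2d(4, eps=bn_old.eps, momentum=bn_old.momentum)

# The new layer uses the default initialization, not bn_old's values:
print(bn_new.weight)        # all ones
print(bn_new.bias)          # all zeros
print(bn_new.running_mean)  # all zeros
print(bn_new.running_var)   # all ones
```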
Could you explain your use case? Maybe I’m misunderstanding it.
Suppose the BatchNorm layer follows a Conv2d layer with out_channels=64, so the BatchNorm layer is nn.BatchNorm2d(num_features=64).
If the previous layer's filter at index 10 is pruned (so the number of output feature maps drops by 1), the BatchNorm layer should become nn.BatchNorm2d(num_features=63).
The weight, bias, and running stats (running_mean, running_var) of the BatchNorm layer are still useful. How can I keep them in the new BatchNorm layer? Something like new_param[0:10] = old_param[0:10]; new_param[10:] = old_param[11:].
Thank you!
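The slicing described in this question can be sketched as a function that builds a smaller layer and copies the kept channels into it. This is a minimal sketch, not an official PyTorch API: `prune_batchnorm` is a hypothetical helper, it drops a single channel, and it assumes the old layer lives on the CPU (move the result with `.to(device)` otherwise).

```python
import torch
import torch.nn as nn

def prune_batchnorm(old: nn.BatchNorm2d, drop: int) -> nn.BatchNorm2d:
    """Hypothetical helper: return a new BatchNorm2d without channel `drop`,
    copying weight, bias and running stats of every kept channel."""
    keep = [i for i in range(old.num_features) if i != drop]
    new = nn.BatchNorm2d(
        len(keep),
        eps=old.eps,
        momentum=old.momentum,
        affine=old.affine,
        track_running_stats=old.track_running_stats,
    )
    idx = torch.tensor(keep, dtype=torch.long)
    with torch.no_grad():
        if old.affine:
            # Same as new[0:drop] = old[0:drop]; new[drop:] = old[drop + 1:]
            new.weight.copy_(old.weight[idx])
            new.bias.copy_(old.bias[idx])
        if old.track_running_stats:
            new.running_mean.copy_(old.running_mean[idx])
            new.running_var.copy_(old.running_var[idx])
            new.num_batches_tracked.copy_(old.num_batches_tracked)
    return new

bn = nn.BatchNorm2d(64)
with torch.no_grad():
    # Give each channel a recognizable weight: channel i has weight i.
    bn.weight.copy_(torch.arange(64, dtype=torch.float32))
bn_pruned = prune_batchnorm(bn, drop=10)
print(bn_pruned.num_features, bn_pruned.weight[9].item(), bn_pruned.weight[10].item())
```

Building a fresh module (instead of reassigning attributes on the old one) keeps num_features consistent with the tensor shapes automatically.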
Hi, what should I do if I want to keep the weight, bias, and running stats of the old BatchNorm layer?
The details are in my reply above.
Slicing out the pruned channel like this should generally work.
Here is a small code sample showing that the model outputs the right shape after pruning:
```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(64)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        return x

model = MyModel()
x = torch.randn(1, 1, 10, 10)
output = model(x)
output.sum().backward()
print(model.conv1.weight.grad)
model.zero_grad()

# Prune channel 10 (output channel of conv1 and feature of bn1)
with torch.no_grad():
    model.conv1.weight = nn.Parameter(torch.cat((
        model.conv1.weight[:10], model.conv1.weight[11:])))
    model.conv1.bias = nn.Parameter(torch.cat((
        model.conv1.bias[:10], model.conv1.bias[11:])))
    model.bn1.weight = nn.Parameter(torch.cat((
        model.bn1.weight[:10], model.bn1.weight[11:])))
    model.bn1.bias = nn.Parameter(torch.cat((
        model.bn1.bias[:10], model.bn1.bias[11:])))
    # running_mean / running_var are buffers, so plain tensors are assigned
    model.bn1.running_mean = torch.cat((
        model.bn1.running_mean[:10], model.bn1.running_mean[11:]))
    model.bn1.running_var = torch.cat((
        model.bn1.running_var[:10], model.bn1.running_var[11:]))
# Keep the bookkeeping attributes consistent with the new shapes
model.conv1.out_channels = 63
model.bn1.num_features = 63
# Note: an optimizer created before this point still holds the old
# Parameters, so create a new optimizer after pruning.

output = model(x)
print(output.shape)  # torch.Size([1, 63, 10, 10])
output.sum().backward()
print(model.conv1.weight.grad)
model.zero_grad()
```
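Beyond the output shape, one can check that pruning leaves the remaining channels untouched: BatchNorm normalizes each channel independently, so the pruned model's output should equal the original output with channel 10 removed. A sketch, assuming eval mode (so BatchNorm uses the copied running stats) and a hypothetical `keep_except` helper; `MyModel` is redefined so the block runs on its own, and a small tolerance allows for floating-point differences between convolution kernels.

```python
import copy
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(64)

    def forward(self, x):
        return self.bn1(self.conv1(x))

def keep_except(t: torch.Tensor, drop: int) -> torch.Tensor:
    # Hypothetical helper: remove entry `drop` along dim 0.
    return torch.cat((t[:drop], t[drop + 1:]))

torch.manual_seed(0)
model = MyModel().eval()  # eval mode: BatchNorm uses running stats
x = torch.randn(2, 1, 10, 10)
with torch.no_grad():
    ref = model(x)

drop = 10
pruned = copy.deepcopy(model)
with torch.no_grad():
    pruned.conv1.weight = nn.Parameter(keep_except(pruned.conv1.weight, drop))
    pruned.conv1.bias = nn.Parameter(keep_except(pruned.conv1.bias, drop))
    pruned.bn1.weight = nn.Parameter(keep_except(pruned.bn1.weight, drop))
    pruned.bn1.bias = nn.Parameter(keep_except(pruned.bn1.bias, drop))
    pruned.bn1.running_mean = keep_except(pruned.bn1.running_mean, drop)
    pruned.bn1.running_var = keep_except(pruned.bn1.running_var, drop)
pruned.conv1.out_channels = 63
pruned.bn1.num_features = 63

with torch.no_grad():
    out = pruned(x)
# The original output with channel `drop` removed along the channel dim
expected = torch.cat((ref[:, :drop], ref[:, drop + 1:]), dim=1)
print(out.shape, torch.allclose(out, expected, atol=1e-5))
```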