I am trying to replace a certain type of layer (e.g. nn.Conv2d) with my own custom variant, without having to copy-paste the model definition code and replace the relevant lines.
I tried to do this using the solution from the thread "Replacing convs modules with custom convs, then NotImplementedError".
However, when I use it with the ResNet model from torchvision, it does not replace the Conv2d layers in the downsample blocks, and with the AlexNet model from torchvision it does not replace any Conv2d layer at all.
What is the recommended way to replace one type of layer with another, such that the original model definition does not need to be copy-pasted and adapted?
Layers can only be replaced when the input and output dimensions are the same.
And unless I am mistaken, the layers in ResNet are grouped into blocks rather than stored individually, so you may have to replace an entire block of them.
Anyway, I think it's better to replace the layers near the end, as their weights are more likely to carry less information, and likewise the dense layers at the end. Those shouldn't give any trouble.
Just be careful with dimensions and you should be fine I think.
Preview the model by printing its repr before trying anything.
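A minimal sketch of that preview step. It uses a small hypothetical `ToyNet` instead of a torchvision model so it runs without downloads; besides printing the repr, it lists the fully qualified names of every Conv2d via `named_modules()`, which shows that layers inside a Sequential get numeric names such as `features.0`:

```python
import torch.nn as nn

# Hypothetical toy model: one conv stored as an attribute plus a Sequential,
# similar to how torchvision stores `features` (AlexNet/VGG) or `downsample` (ResNet).
class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.features = nn.Sequential(
            nn.Conv2d(8, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, 16, kernel_size=1),
        )

    def forward(self, x):
        return self.features(self.conv1(x))

model = ToyNet()
print(model)  # the repr shows the nesting and the child names

# Dotted names of every Conv2d; children of a Sequential are named "0", "1", ...
conv_names = [name for name, m in model.named_modules() if isinstance(m, nn.Conv2d)]
print(conv_names)
```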
@sansmoriaxz I think you didn't fully understand my question. I'm talking about replacing one layer type with another, for example replacing torch.nn.Conv2d(*args, **kwargs) with mypackage.CustomConv2d(*args, **kwargs).
@albanD you replied in the similar thread "How to replace all ReLU activations in a pretrained network?" that it can be done in place, but it is not clear to me exactly how that should be done.
There is no general guideline as it depends a lot on how the network you want to modify is structured as nn.Modules.
If you have your convs as attributes such as self.conv2, then you need to reassign those attributes.
If they are in a Sequential, you can find them and replace each one by its index (nn.Sequential supports item assignment, e.g. seq[conv_idx] = MyConv(...)).
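A minimal sketch of both cases, assuming a hypothetical `MyConv` that subclasses nn.Conv2d (so it accepts the same constructor arguments) and a hypothetical helper `make_my_conv` that rebuilds a conv with the same configuration and copies its weights:

```python
import torch.nn as nn

class MyConv(nn.Conv2d):
    """Hypothetical custom conv; subclassing keeps Conv2d's constructor."""

def make_my_conv(old: nn.Conv2d) -> MyConv:
    # Rebuild with the same configuration and copy the (pretrained) weights.
    new = MyConv(old.in_channels, old.out_channels, old.kernel_size,
                 stride=old.stride, padding=old.padding, dilation=old.dilation,
                 groups=old.groups, bias=old.bias is not None,
                 padding_mode=old.padding_mode)
    new.load_state_dict(old.state_dict())
    return new

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv2 = nn.Conv2d(3, 8, kernel_size=3)
        self.block = nn.Sequential(nn.Conv2d(8, 8, kernel_size=3), nn.ReLU())

net = Net()

# Case 1: conv stored as an attribute such as self.conv2 -> assign a new module.
net.conv2 = make_my_conv(net.conv2)

# Case 2: conv inside a Sequential -> locate it by index and assign by index.
for idx, layer in list(enumerate(net.block)):
    if isinstance(layer, nn.Conv2d):
        net.block[idx] = make_my_conv(layer)
```

Copying the state dict keeps pretrained weights; drop that line if the custom layer should start from fresh initialization.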
If the model is defined in your own Python file, you can create the convs through another function, like:
```python
def new_conv(*args, **kwargs):
    return MyConv(*args, **kwargs)

# In your net definition
self.conv = new_conv(foo, bar)
```
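A closely related variant of the factory idea, sketched under the assumption that you control the model code: pass the conv class itself as a constructor argument, so switching to the custom layer needs no copy-paste. `MyConv` (which here only counts forward calls) and `Net` are hypothetical:

```python
import torch
import torch.nn as nn

class MyConv(nn.Conv2d):
    """Hypothetical custom conv; here it only counts forward calls."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.calls = 0

    def forward(self, x):
        self.calls += 1
        return super().forward(x)

class Net(nn.Module):
    # The conv class is a constructor argument, defaulting to the stock layer.
    def __init__(self, conv_layer=nn.Conv2d):
        super().__init__()
        self.conv1 = conv_layer(3, 8, kernel_size=3, padding=1)
        self.conv2 = conv_layer(8, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv2(self.conv1(x))

plain = Net()                      # uses nn.Conv2d
custom = Net(conv_layer=MyConv)    # same architecture, custom convs
out = custom(torch.randn(1, 3, 16, 16))
```

torchvision's ResNet constructor uses the same pattern for normalization layers through its `norm_layer` argument, which can be handy for swapping BatchNorm2d without touching the model code.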
An easy approach that might work in some cases is to create a module that contains a conv layer but also does something before or after the conv layer's forward pass.
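A minimal sketch of such a wrapper. The pre- and post-processing (clamping the input, applying ReLU to the output) are arbitrary placeholders, and `ConvWrapper` is a hypothetical name:

```python
import torch
import torch.nn as nn

class ConvWrapper(nn.Module):
    """Hypothetical wrapper: keeps the original conv (and its weights)
    and adds extra behaviour before and after its forward pass."""
    def __init__(self, conv: nn.Conv2d):
        super().__init__()
        self.conv = conv

    def forward(self, x):
        x = x.clamp(-3.0, 3.0)   # placeholder pre-processing
        y = self.conv(x)
        return torch.relu(y)     # placeholder post-processing

conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
wrapped = ConvWrapper(conv)
y = wrapped(torch.randn(2, 3, 8, 8))
```

One side effect to keep in mind: if you install the wrapper in place (e.g. `model.conv = ConvWrapper(model.conv)`), the state_dict keys gain an extra level (`conv.weight` becomes `conv.conv.weight`), which matters when loading older checkpoints.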
You can use something like this for replacing layers.
The code in the picture above replaces all MaxPool2d layers with AvgPool2d.
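A minimal sketch of the same idea (my own version, not the code from the picture): recurse over `named_children()` and use `setattr` on the parent. `setattr` also works for Sequential children, whose names are "0", "1", and so on:

```python
import torch.nn as nn

def replace_maxpool(module: nn.Module) -> None:
    """Recursively swap every nn.MaxPool2d for an nn.AvgPool2d, in place."""
    for name, child in list(module.named_children()):
        if isinstance(child, nn.MaxPool2d):
            new = nn.AvgPool2d(child.kernel_size, stride=child.stride,
                               padding=child.padding, ceil_mode=child.ceil_mode)
            setattr(module, name, new)   # works for names like "1" too
        else:
            replace_maxpool(child)       # descend into blocks and containers

model = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.MaxPool2d(2),
    nn.Sequential(nn.Conv2d(8, 8, 3), nn.MaxPool2d(2, stride=2)),
)
replace_maxpool(model)
```

For a custom layer, replace the `isinstance` check and the constructor call with your own types.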
I think this will work for you; just change it to your custom layer. Let us know if it works:
```python
import torch

def replace_bn(module, name):
    '''
    Recursively put desired batch norm in nn.module module.

    set module = net to start code.
    '''
    # go through all attributes of module nn.module (e.g. network or layer) and put batch norms if present
    for attr_str in dir(module):
        target_attr = getattr(module, attr_str)  # use `module`; an undefined `m` raises NameError
        if type(target_attr) == torch.nn.BatchNorm2d:
            print('replaced: ', name, attr_str)
            new_bn = torch.nn.BatchNorm2d(target_attr.num_features, target_attr.eps, target_attr.momentum, target_attr.affine,
                                          track_running_stats=False)  # example change; adjust as needed
            setattr(module, attr_str, new_bn)

    # iterate through immediate child modules. Note, the recursion is done by our code no need to use named_modules()
    for name, immediate_child_module in module.named_children():
        replace_bn(immediate_child_module, name)

# usage: replace_bn(net, 'net')
```
original post: How to modify a pretrained model
I tried this on ResNet to change BN into GN, but it leaves 3 BN layers unchanged:
```
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): GroupNorm(1, 64, eps=1e-05, affine=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): GroupNorm(1, 64, eps=1e-05, affine=False)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): GroupNorm(1, 256, eps=1e-05, affine=False)
(relu): ELU(alpha=1.0, inplace=True)
(0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
```
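A likely explanation: `nn.Module.__dir__` drops names that start with a digit (they are not valid Python identifiers), so a `dir()`-based loop never sees the Sequential children "0" and "1" inside `downsample`. If the linked conv-replacement solution also walks `dir()`, the same filter would explain why nothing inside AlexNet's `features` Sequential was replaced. A sketch that walks `named_children()` instead, on a hypothetical toy block shaped like a ResNet bottleneck (GroupNorm with one group, default `affine=True`):

```python
import torch.nn as nn

# Hypothetical toy block: a named BN attribute plus a `downsample`
# Sequential whose children are named "0" and "1", as in ResNet.
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(64, 64, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.downsample = nn.Sequential(nn.Conv2d(64, 256, 1, bias=False),
                                        nn.BatchNorm2d(256))

block = Block()
# dir() hides digit names, so a dir()-based loop cannot reach these children.
visible = [a for a in ("0", "1") if a in dir(block.downsample)]

def bn_to_gn(module: nn.Module) -> None:
    """Replace every BatchNorm2d with GroupNorm(1, C), walking named_children()."""
    for name, child in list(module.named_children()):
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, nn.GroupNorm(1, child.num_features))
        else:
            bn_to_gn(child)

bn_to_gn(block)
```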
Nice tutorial! It was very informative and extremely helpful.
I get the following error when I try this approach:
TypeError: 'VGG' object is not iterable
Any idea why this is the case?
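A likely cause, assuming the code loops over the model object directly (e.g. `for layer in model:`): only containers such as nn.Sequential and nn.ModuleList are iterable, while a plain nn.Module subclass like torchvision's VGG defines no `__iter__`. Iterating over `children()` or `named_modules()` works for any module. A sketch with a hypothetical `TinyVGG` stand-in:

```python
import torch.nn as nn

class TinyVGG(nn.Module):
    """Hypothetical stand-in: like VGG, a plain nn.Module whose layers
    live in a `features` Sequential."""
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.MaxPool2d(2))

model = TinyVGG()

error_message = ""
try:
    for layer in model:          # plain modules are not iterable
        pass
except TypeError as err:
    error_message = str(err)

# These work on any nn.Module:
children = list(model.children())
convs = [n for n, m in model.named_modules() if isinstance(m, nn.Conv2d)]
```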
I get an error because m is not defined. Any idea what I am doing wrong? Where does m come from?
When layers are stored in a list-like container, getattr becomes a problem: it does not understand a dotted name like 'layer1.0.bn1', which you would otherwise address as net.layer1[0].bn1. To resolve this, I wrote the following script.
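```python
import torch.nn as nn

def get_all_parent_layers(net, layer_type):
    layers = []
    for name, l in net.named_modules():
        if isinstance(l, layer_type):
            # e.g. "layer1.0.bn1" -> ["layer1", "0", "bn1"]
            tokens = name.strip().split('.')
            layer = net
            # walk down to the parent of the matching layer
            for t in tokens[:-1]:
                if not t.isnumeric():
                    layer = getattr(layer, t)
                else:
                    layer = layer[int(t)]
            layers.append([layer, tokens[-1]])
    return layers

for parent_layer, last_token in get_all_parent_layers(net, nn.BatchNorm2d):
    setattr(parent_layer, last_token, nn.Identity())
```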
First, we collect the parent layer and the last name token of every layer that matches a particular type, such as nn.BatchNorm2d. Then we iterate over the list of (parent_layer, last_token) pairs and use setattr to perform the desired change.
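On recent PyTorch versions (an assumption about your install), `nn.Module.get_submodule` resolves a dotted name directly, so the token walk can be replaced by `name.rpartition('.')`. A sketch that applies this to the original question, swapping every Conv2d for a hypothetical `CustomConv2d` subclass while keeping the weights; `replace_layers` and `to_custom` are illustrative names:

```python
import torch.nn as nn

class CustomConv2d(nn.Conv2d):
    """Hypothetical custom variant of nn.Conv2d (same constructor)."""

def replace_layers(net: nn.Module, layer_type, make_new) -> None:
    # Collect names first, then modify, so the module tree is not changed mid-walk.
    targets = [name for name, m in net.named_modules() if isinstance(m, layer_type)]
    for name in targets:
        parent_name, _, last_token = name.rpartition('.')
        parent = net.get_submodule(parent_name)  # "" returns net itself
        setattr(parent, last_token, make_new(getattr(parent, last_token)))

def to_custom(old: nn.Conv2d) -> CustomConv2d:
    new = CustomConv2d(old.in_channels, old.out_channels, old.kernel_size,
                       stride=old.stride, padding=old.padding, dilation=old.dilation,
                       groups=old.groups, bias=old.bias is not None,
                       padding_mode=old.padding_mode)
    new.load_state_dict(old.state_dict())  # keep the original weights
    return new

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Sequential(nn.Conv2d(8, 8, 1)))
replace_layers(net, nn.Conv2d, to_custom)
```

Passing a factory (`make_new`) keeps the traversal generic; the same function could turn BatchNorm2d into GroupNorm or nn.Identity.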