How to replace all ReLU activations in a pretrained network?

Greetings, I am trying to replace all ReLUs in a DenseNet with another activation. I would also like to replace all the BatchNorm layers with GroupNorm layers. However, I have a hard time accessing the modules of the DenseNet, since it is made up of a bunch of Sequential modules. Any advice would help, thanks.

Hello.

It will depend on how the original model is implemented.
If it creates modules for the ReLU/batchnorm during the initialization, you can just replace these modules wherever they are and then the forward method will use your new modules instead.
If it uses the functional interface directly in the forward() method of the Module (i.e. it calls nn.functional.relu()), then you will have to modify the forward() function itself to replace these calls with the activation function you want to use.
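
A minimal sketch of the first case, assuming a hypothetical toy model (TinyNet, not from the thread) that creates its activation in __init__: assigning a new module to the same attribute is enough, because forward() looks the attribute up at call time.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):  # hypothetical toy model for illustration
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 3)
        self.act = nn.ReLU()  # registered submodule, created in __init__

    def forward(self, x):
        # self.act is resolved at call time, so a replacement is picked up
        return self.act(self.fc(x))


net = TinyNet()
net.act = nn.Tanh()  # swap the registered module in place

x = torch.randn(2, 3)
out = net(x)  # now computes tanh(fc(x))
```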

Hello, thank you for your reply. I am using the pretrained version of the DenseNet from torchvision models. As far as I can tell, all the dense block attributes of the model are nn.Sequential and contain BN and ReLU modules. Should I then rewrite the whole sequential module and replace the original one with it? The image may explain a little better what I am trying to say:

[image: printout of the DenseNet's nn.Sequential modules containing BN and ReLU layers]

No, you can just change the modules in place.
If m is the top module, you should be able to do m.features[2] = NewActivation() to change the first ReLU (named relu0) there.
Then you can do the same for all ReLUs.
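
A minimal sketch of that index assignment, using a small hypothetical stand-in for DenseNet's features block (it reuses torchvision's child names conv0/norm0/relu0/pool0 but is not the real model). nn.Sequential supports assignment by integer index, and the replacement keeps the original child name; nn.ELU stands in for NewActivation.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# Hypothetical stand-in for densenet.features (not the real torchvision model)
features = nn.Sequential(OrderedDict([
    ("conv0", nn.Conv2d(3, 8, kernel_size=3, padding=1)),
    ("norm0", nn.BatchNorm2d(8)),
    ("relu0", nn.ReLU(inplace=True)),
    ("pool0", nn.MaxPool2d(2)),
]))

# Index 2 is relu0; the new module is registered under the same name
features[2] = nn.ELU()
```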

Be careful when changing the BatchNorm layers: they have learnable parameters and running statistics. If you remove these, you might see a drop in performance if you’re using the trained version of the model.

Great, thank you! I will try replacing the activations as you suggested. In my case it looks like it will be more appropriate to get the untrained DenseNet, replace the activations and BNs, and train it on my dataset.
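
For the BatchNorm to GroupNorm part of the original question, a minimal sketch that recurses through named_children() and swaps each BatchNorm2d for a freshly initialised nn.GroupNorm. The function name bn_to_groupnorm and the group count gcd(32, channels) are assumptions for illustration (32 is a common default; taking the gcd guarantees the group count divides the channel count). No learned BN statistics are carried over, so this fits the plan of training from scratch.

```python
import math

import torch
import torch.nn as nn


def bn_to_groupnorm(module: nn.Module, max_groups: int = 32) -> None:
    """Hypothetical helper: replace every BatchNorm2d below `module` with GroupNorm."""
    # list(...) so we never iterate a dict we are modifying
    for name, child in list(module.named_children()):
        if isinstance(child, nn.BatchNorm2d):
            channels = child.num_features
            groups = math.gcd(max_groups, channels)  # always divides `channels`
            setattr(module, name,
                    nn.GroupNorm(groups, channels, eps=child.eps, affine=child.affine))
        else:
            bn_to_groupnorm(child, max_groups)


# Small nested model for demonstration (a torchvision DenseNet works the same way)
model = nn.Sequential(
    nn.Conv2d(3, 48, 3), nn.BatchNorm2d(48), nn.ReLU(),
    nn.Sequential(nn.Conv2d(48, 20, 1), nn.BatchNorm2d(20)),
)
bn_to_groupnorm(model)
```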

Is there a quicker way than listing out all the indices you want to change?

This works:

model.features[2] = NewActivation()

This is closer to what I want to do, but the assignment doesn’t work:

for i, (name, layer) in enumerate(model.named_modules()):
    if isinstance(layer, nn.ReLU):
        layer = NewActivation()

This is a small function I wrote that seems to work with a lot of the torchvision models. It tries to find all ReLU layers and convert them to Softplus:

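import torch.nn as nn

def convert_relu_to_softplus(model):
    # Walk the immediate children; replace a ReLU on its parent, otherwise recurse
    for child_name, child in model.named_children():
        if isinstance(child, nn.ReLU):
            setattr(model, child_name, nn.Softplus())
        else:
            convert_relu_to_softplus(child)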
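
A possible generalisation of the same recursion, sketched under the assumption that you want to swap any layer type: replace_layers (a hypothetical name) takes the type to look for and a factory that receives the old module, so hyperparameters such as num_features can be copied into the new layer. Iterating over a list() copy of named_children() keeps the loop safe even though children are reassigned.

```python
import torch
import torch.nn as nn


def replace_layers(model: nn.Module, old_type, make_new) -> None:
    """Hypothetical helper: replace every child of type `old_type` with make_new(old)."""
    for child_name, child in list(model.named_children()):
        if isinstance(child, old_type):
            setattr(model, child_name, make_new(child))
        else:
            replace_layers(child, old_type, make_new)


model = nn.Sequential(
    nn.Linear(4, 4), nn.ReLU(),
    nn.Sequential(nn.Linear(4, 2), nn.ReLU()),
)
# Same effect as convert_relu_to_softplus; the factory ignores the old module here
replace_layers(model, nn.ReLU, lambda old: nn.Softplus())
```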

Hi,

Does the solution by @cassidylaidlaw also work for a loaded model that applies the nonlinearity in its forward call, rather than creating it in __init__ or inside an nn.Sequential?

Thanks!

It won’t work for a model defined like so:

class MyModel(nn.Module):
    def forward(self, x):
        ...
        x = F.relu(x)
        ...

However, it would work for a model like this:

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        ...
        self.relu = nn.ReLU()
        ...
    def forward(self, x):
        ...
        x = self.relu(x)
        ...
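
For the first kind of model (F.relu called inside forward), one possible workaround not discussed in the thread is torch.fx: trace the model into a graph, retarget the functional ReLU calls, and regenerate the code. This is a sketch that assumes the model is symbolically traceable (e.g. no data-dependent Python control flow); FunctionalNet and swap_functional_relu are hypothetical names.

```python
import torch
import torch.fx
import torch.nn as nn
import torch.nn.functional as F


class FunctionalNet(nn.Module):  # hypothetical model using F.relu in forward()
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        return F.relu(self.fc(x))


def swap_functional_relu(model: nn.Module) -> torch.fx.GraphModule:
    """Trace `model` and replace functional ReLU calls with F.softplus."""
    traced = torch.fx.symbolic_trace(model)
    for node in traced.graph.nodes:
        if node.op == "call_function" and node.target in (F.relu, torch.relu):
            node.target = F.softplus
            node.args = node.args[:1]  # keep only the input tensor
            node.kwargs = {}           # drop ReLU-only options such as inplace
    traced.graph.lint()
    traced.recompile()  # regenerate forward() from the edited graph
    return traced


model = FunctionalNet()
new_model = swap_functional_relu(model)
```

The original model is left untouched; new_model is a separate GraphModule whose forward() calls softplus instead of relu.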

Hi @cassidylaidlaw,

I’m surprised that worked for you. I tried extending it to batch norm layers and it failed. See my example using a pretrained ResNet-18:

for name, module in model.named_modules():
    # if type(module) == torch.nn.BatchNorm2d:
    if 'bn' in name:
        print(name)
        print(module)
        new_module = eval('torch.nn.'+str(module).replace('track_running_stats=True', 'track_running_stats=False'))
        # module.load_state_dict(new_module.state_dict())
        setattr(model, name, new_module)

I get the error:

Traceback (most recent call last):
  File "/Users/brando/anaconda3/envs/myenv/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3343, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-78-4fd21dfb90a9>", line 23, in <module>
    for name, module in model.named_modules():
  File "/Users/brando/anaconda3/envs/myenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1246, in named_modules
    for name, module in self._modules.items():
RuntimeError: OrderedDict mutated during iteration

@albanD do you know how to do this? I want to replace any arbitrary layer in a pretrained net with my own custom layer, or with the same layer using different hyperparameters.


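I had already tried a variant of your code, but in my use case it inserts the layers in a messy way: the new batch norms end up as extra top-level attributes (e.g. layer1.0.bn1) instead of replacing the nested ones. It seems I need to identify the nn.Sequential containers and other submodules to do this safely…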

model
Out[80]: 
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=5, bias=True)
  (layer1.0.bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (layer1.0.bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (layer1.1.bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (layer1.1.bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (layer2.0.bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (layer2.0.bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (layer2.1.bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (layer2.1.bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (layer3.0.bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (layer3.0.bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (layer3.1.bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (layer3.1.bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (layer4.0.bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (layer4.0.bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (layer4.1.bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (layer4.1.bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
)

This looks wrong: there are too many attributes. Code used:

for name, module in copy.deepcopy(model).named_modules():
    # if type(module) == torch.nn.BatchNorm2d:
    if 'bn' in name:
        print(name)
        print(module)
        new_module = eval('torch.nn.'+str(module).replace('track_running_stats=True', 'track_running_stats=False'))
        # module.load_state_dict(new_module.state_dict())
        # model.__setattr__(name, module)
        setattr(model, name, new_module)
model.fc = torch.nn.Linear(in_features=512, out_features=fc_out_features, bias=True)
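
The extra entries in the printout come from setattr(model, name, ...) with a dotted name such as layer1.0.bn1: it registers a new top-level child literally called "layer1.0.bn1" instead of descending into layer1, then 0, then bn1. A minimal sketch of resolving the dotted path first, assuming PyTorch 1.9 or newer (for nn.Module.get_submodule); set_module_by_name is a hypothetical helper name.

```python
import torch.nn as nn


def set_module_by_name(root: nn.Module, dotted_name: str, new_module: nn.Module) -> None:
    """Hypothetical helper: replace the submodule at a path like 'layer1.0.bn1'."""
    parent_name, _, child_name = dotted_name.rpartition(".")
    # Walk to the direct parent, then set the last component on it
    parent = root.get_submodule(parent_name) if parent_name else root
    setattr(parent, child_name, new_module)


# Nested example: the BatchNorm lives at the dotted path "0.1"
model = nn.Sequential(nn.Sequential(nn.Conv2d(3, 4, 1), nn.BatchNorm2d(4)))
set_module_by_name(model, "0.1", nn.BatchNorm2d(4, track_running_stats=False))
```

On a ResNet the same call would look like set_module_by_name(model, "layer1.0.bn1", new_bn).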

The problem, as the error states, is just that you modify the dictionary as you go through it, and Python does not allow that.
Just append the modifications that need to be done to a list during the for loop, and then apply all of them after you’re done iterating over the dictionary.
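
A minimal sketch of that collect-then-apply pattern, assuming the goal from this thread (BatchNorm2d with track_running_stats=False). Recording (parent, attribute name) pairs sidesteps dotted names entirely; the function name is hypothetical.

```python
import torch
import torch.nn as nn


def replace_bn_collect_then_apply(model: nn.Module) -> int:
    """Hypothetical helper: swap every BatchNorm2d for one without running stats."""
    # 1) Iterate without modifying anything; just remember what to change
    to_replace = []
    for parent in model.modules():
        for name, child in parent.named_children():
            if isinstance(child, nn.BatchNorm2d):
                to_replace.append((parent, name, child))

    # 2) Iteration is over, so reassigning children is now safe
    for parent, name, old in to_replace:
        setattr(parent, name, nn.BatchNorm2d(
            old.num_features, eps=old.eps, momentum=old.momentum,
            affine=old.affine, track_running_stats=False))
    return len(to_replace)


model = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8),
    nn.Sequential(nn.Conv2d(8, 8, 1), nn.BatchNorm2d(8)),
)
n_replaced = replace_bn_collect_then_apply(model)
```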

Got it working! I thought of your idea but didn’t see how to do it well. Perhaps you could share your code so I can see how it should have been done. For now, this is what works for me:

import torch

def replace_bn(module, name):
    '''
    Recursively put desired batch norm in nn.module module.

    set module = net to start code.
    '''
    # go through the immediate children of module (e.g. network or layer) and swap batch norms if present;
    # named_children() also yields nn.Sequential children named "0", "1", ... (e.g. ResNet's downsample)
    for attr_str, target_attr in list(module.named_children()):
        if type(target_attr) == torch.nn.BatchNorm2d:
            print('replaced: ', name, attr_str)
            new_bn = torch.nn.BatchNorm2d(target_attr.num_features, target_attr.eps, target_attr.momentum, target_attr.affine,
                                          track_running_stats=False)
            setattr(module, attr_str, new_bn)

    # iterate through immediate child modules. Note, the recursion is done by our code no need to use named_modules()
    for name, immediate_child_module in module.named_children():
        replace_bn(immediate_child_module, name)

replace_bn(model, 'model')

source post: How to modify a pretrained model

Thanks AlbanD! 🙂

Thanks @cassidylaidlaw man!