Dynamically insert a new layer into the middle of a pre-trained model

I would like to insert a new layer dynamically into an existing (pre-trained) model. By »dynamically« I mean I would like to have a function foo which takes

  • an existing model
  • a coordinate (e.g. an index or a label)
  • a layer to insert

and returns the existing model (with its weights intact) including the inserted layer at the provided coordinate. (These requirements on foo are not compulsory. It would be totally fine if the solution were some kind of mapping function or similar, as long as it enables me to focus on the part of the model that changes and takes care of the remaining parts for me.)
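For illustration, the signature I have in mind would be something like this (the name foo and the coordinate type are just placeholders):

import torch.nn as nn

def foo(model: nn.Module, coordinate, layer: nn.Module) -> nn.Module:
    """Insert `layer` into `model` at `coordinate`, keeping all existing weights."""
    ...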

Take AlexNet for example:

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

In this example, I want to insert a torch.nn.Identity layer after the ReLU layer (7) within the (features) group. What I could do (if I understand this thread correctly) is this:

import torch.nn as nn
from torchvision import models

model = models.alexnet(pretrained=True)
feats = list(model.features.children())
feats.insert(8, nn.Identity())           # insert after the ReLU at index 7
model.features = nn.Sequential(*feats)   # note the unpacking: Sequential takes modules, not a list

The result is then:

AlexNet(
  (features): Sequential(
    ...
    (7): ReLU(inplace=True)
    (8): Identity()
    (9): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    ...
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    ...
  )
)
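As a sanity check, the shifted Conv2d should still carry its pre-trained weights, which one can verify against a freshly loaded copy:

import torch

# After the insertion, the Conv2d formerly at features[8] sits at features[9].
# Since the list insert reuses the same module objects, its weights are untouched.
reference = models.alexnet(pretrained=True)
assert torch.equal(model.features[9].weight, reference.features[8].weight)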

However, this hardcoded approach does not scale. In the end I would like to insert non-trivial layers dynamically into several models at several positions.

How do I achieve this?


My suggestion would be to initialize the new model first and then copy the weights back for everything except the new layers. That way you can be pretty sure that everything in the network is still connected correctly.

You can refer to any dynamically expanding network implementation for this.
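For instance, here is a minimal sketch of that copy-back idea, reusing the AlexNet/Identity example from above (the index 8 is specific to that example):

import torch.nn as nn
from torchvision import models

# Build a fresh model that already contains the new layer, then copy the
# pre-trained weights over layer by layer. Indices at and after the
# insertion point are shifted by one in the new model.
pretrained = models.alexnet(pretrained=True)

new_model = models.alexnet(pretrained=False)
feats = list(new_model.features.children())
feats.insert(8, nn.Identity())
new_model.features = nn.Sequential(*feats)

for old_idx, layer in enumerate(pretrained.features):
    new_idx = old_idx if old_idx < 8 else old_idx + 1
    new_model.features[new_idx].load_state_dict(layer.state_dict())

# avgpool has no parameters; the classifier is unchanged, so copy it wholesale.
new_model.classifier.load_state_dict(pretrained.classifier.state_dict())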

I wrote a small example below. It takes a slightly different route: instead of copying weights, it wraps the existing layer together with the new module in an nn.Sequential, so the weights move with the layer. You can verify that the weights of the original model.features[6] are preserved in model.features[6][0].

import torch.nn as nn
from torchvision import models


def insert_module(model, indices, modules):
    """Wrap the submodule at each flat index in nn.Sequential(original, new)."""
    indices = indices if isinstance(indices, list) else [indices]
    modules = modules if isinstance(modules, list) else [modules]
    assert len(indices) == len(modules)

    # Flat list of all submodule names, skipping the root module itself.
    layer_names = [name for name, _ in model.named_modules()][1:]
    for index, module in zip(indices, modules):
        # Walk down to the parent of the target module, then replace the
        # target with a Sequential of the original module and the new one.
        *parent_path, child_name = layer_names[index].split('.')
        parent = model
        for part in parent_path:
            parent = getattr(parent, part)
        target = getattr(parent, child_name)
        setattr(parent, child_name, nn.Sequential(target, module))


model = models.alexnet(pretrained=True)
print(model)
insert_module(model, [7, 9], [nn.Identity(), nn.ReLU(inplace=True)])
print(model)
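The second print(model) should then show the two wrapped layers roughly like this:

(features): Sequential(
  ...
  (6): Sequential(
    (0): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Identity()
  )
  (7): ReLU(inplace=True)
  (8): Sequential(
    (0): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
  )
  ...
)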

PS: If the inserted nn.Identity() module does not have to run sequentially right after model.features[7], then the insertion location does not matter in the first place; the example above therefore assumes that you want nn.Identity() to be forwarded immediately after model.features[7].