Insert New Layer in the Middle of a Pre-trained Model

Help!!! Does anyone know how to insert a new layer in the middle of a pretrained model? E.g., insert a new conv in the middle of ResNet’s bottleneck.


There’s no easy way to insert a new layer in the middle of an existing model as far as I’m aware. A definite solution is to build the structure that you want in a new class and then copy the corresponding weights over from the pretrained model.


Figured it out!!!
After loading the model, we can directly assign model.conv_x = nn.Sequential(new_layer, model.conv_x) (nn.Sequential takes the modules as separate arguments, not as a list). This way, we can still use the pretrained model.conv_x.


@wangchust
Can you help a newbie like me? I import the pretrained resnet34 as:

resnet = models.resnet34(pretrained=True)

Now I want to insert a Conv2d layer with a 1x1 kernel before the fc layer to increase the number of channels from 512 to 6000, and then add a 6000 x 6000 fc layer.

I am so new to PyTorch that I need some hand-holding. Can you write the lines of code needed? I am still at the monkey-see, monkey-learn stage. Thanks in advance.


I think the easiest approach would be to derive from ResNet and add your layers.
This should do what you need:


import torch
import torch.nn as nn
from torchvision import models
from torchvision.models.resnet import BasicBlock


class MyResnet2(models.ResNet):
    def __init__(self, block, layers, num_classes=1000):
        super(MyResnet2, self).__init__(block, layers, num_classes)
        self.conv_feat = nn.Conv2d(in_channels=512,
                                   out_channels=6000,
                                   kernel_size=1)
        self.fc = nn.Linear(in_features=6000,
                            out_features=6000)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = self.conv_feat(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


model = MyResnet2(BasicBlock, [3, 4, 6, 3], 1000)
x = torch.randn(1, 3, 224, 224)  # plain tensors work directly; Variable is deprecated
output = model(x)

Note that the shape of x after x = self.avgpool(x) is already [batch, 512, 1, 1] for an input of shape [batch, 3, 224, 224].
You could therefore flatten x and just use two Linear layers, since a Conv2d with kernel_size=1 applied to a 1x1 feature map computes the same thing as a Linear layer.

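The sketch below shows why the 1x1 conv and a Linear layer are interchangeable here. It copies a Conv2d(512, 6000, kernel_size=1) into a Linear(512, 6000) and compares their outputs on a [batch, 512, 1, 1] tensor, which is the shape after avgpool. The sizes follow the thread. The random input is only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(512, 6000, kernel_size=1)
linear = nn.Linear(512, 6000)

# Copy the conv parameters into the linear layer: weight [6000, 512, 1, 1] -> [6000, 512]
with torch.no_grad():
    linear.weight.copy_(conv.weight.view(6000, 512))
    linear.bias.copy_(conv.bias)

x = torch.randn(4, 512, 1, 1)        # shape after avgpool
out_conv = conv(x).flatten(1)        # [4, 6000]
out_linear = linear(x.flatten(1))    # [4, 6000]
print(torch.allclose(out_conv, out_linear, atol=1e-5))
```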

@ptrblck
Thank you, sir!
Now if I re-read some of the tutorials, they will register in my head.


Sorry for the late reply. My approach would be to replace ‘Resnet.fc = nn.Linear(512, num_classes)’ with ‘Resnet.fc = nn.Sequential(nn.Conv2d(512, 6000, kernel_size=1), nn.Linear(6000, 6000))’.


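With this approach you would have to define a Flatten layer. Without it the Sequential model will crash, because nn.Linear(6000, 6000) cannot be applied directly to the 4D [batch, 6000, 1, 1] output of the conv.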

import torch
import torch.nn as nn


class Flatten(nn.Module):
    # Flattens [batch, C, H, W] to [batch, C*H*W] (newer PyTorch versions also provide nn.Flatten)
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return x


seq = nn.Sequential(nn.Conv2d(512, 6000, 1),
                    Flatten(),
                    nn.Linear(6000, 6000))

x = torch.randn(1, 512, 1, 1)
seq(x)

@ptrblck and @wangchust,

Thank you both.

I find this very useful for creating a custom model that inserts a new layer in the middle of ResNet, but how do we get the pretrained weights? This gives me a new network that doesn’t have the pretrained weights.
Thank you.

I want to add attention layers to the pretrained ResNet model. How can I add one after every ResNet block in the model?

This is what I did, and I think it works:

vgg11.features[0] = nn.Sequential(inserted_layer, vgg11.features[0])

I compared the parameters before and after the change, and the pretrained parameters seem to be preserved.

I first divided the model into two parts: the first part contains the layers before the position where you want to add your layer, and the second part contains the layers after it. So something like this:

encoder = nn.Sequential(*list(m.children())[:8])
decoder = nn.Sequential(*list(m.children())[8:])

Then I get a list of the layers in each part, add the layers I need to the first list, and append each layer of the second part to it. Like this:

tmp_1 = list(encoder.children())
tmp_2 = list(decoder.children())
tmp_1.extend(new_layers)  # new_layers: your list of layers to insert
for i in tmp_2:
    tmp_1.append(i)

And then model = nn.Sequential(*tmp_1)


Sir, in this case I would like to use the pretrained weights up to layer4 and do transfer learning for the remaining layers.

My requirement is:

# considering pretrained model
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from torchvision import models
resnet50 = models.resnet50(pretrained=True)

# pruning the layers
prune.l1_unstructured(resnet50.conv1, name='weight', amount=0.4)

# adding layers in the middle (after layer4)
class MyResNet50(nn.Module):

    def __init__(self, my_pretrained_model):
        super(MyResNet50, self).__init__()
        self.pretrained_model = my_pretrained_model
        # layer4 of ResNet-50 outputs 2048 channels
        self.conv_feat = nn.Conv2d(in_channels=2048,
                                   out_channels=512,
                                   kernel_size=3, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # conv_feat + avgpool yields 512 features per sample
        self.my_new_layers = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 1))

    def forward(self, x):
        m = self.pretrained_model
        # run only the pretrained stem and layer1-layer4 (skip its avgpool and fc)
        x = m.maxpool(m.relu(m.bn1(m.conv1(x))))
        x = m.layer4(m.layer3(m.layer2(m.layer1(x))))
        x = self.conv_feat(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.my_new_layers(x)
        return x

# creating object for modified resnet50
myextendedmodel = MyResNet50(resnet50)

Sir, is it possible to get such a modified ResNet-50?

Would this mean the overall model architecture stays the same as in the original ResNet, and you are only training the layers after layer4? If so, you wouldn’t have to add new layers. Instead, you would freeze everything up to and including layer4 by setting the .requires_grad attribute of the corresponding parameters to False.

Thank you, sir, for the solution.

Hi ptrblck,

I have a question regarding loading a pretrained model.
My pretrained model has some 2D convolution layers with the bias disabled (bias=False in nn.Conv2d()). Is there a way to enable the bias for these conv layers and initialize it the way we want?
Thanks in advance

You could use the strict=False argument in load_state_dict as seen here:

import torch.nn as nn

ref = nn.Conv2d(3, 3, 3, bias=False)
conv = nn.Conv2d(3, 3, 3, bias=True)

conv.load_state_dict(ref.state_dict())
# RuntimeError: Error(s) in loading state_dict for Conv2d:
# 	Missing key(s) in state_dict: "bias". 

conv.load_state_dict(ref.state_dict(), strict=False)
# _IncompatibleKeys(missing_keys=['bias'], unexpected_keys=[])

but make sure only the expected keys are missing by checking the returned object.
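Building on that example, this sketch checks the returned object and then initializes the new bias explicitly. Zero initialization is just one choice (an assumption here). It is picked so the conv initially behaves like the bias-free pretrained layer.

```python
import torch
import torch.nn as nn

ref = nn.Conv2d(3, 3, 3, bias=False)   # stands in for the pretrained conv without bias
conv = nn.Conv2d(3, 3, 3, bias=True)

result = conv.load_state_dict(ref.state_dict(), strict=False)
# Fail loudly if anything other than the bias is missing or unexpected
assert result.missing_keys == ["bias"] and result.unexpected_keys == []

# Initialize the new bias however you want; zeros keep the pretrained behavior
nn.init.zeros_(conv.bias)

x = torch.randn(1, 3, 8, 8)
print(torch.allclose(conv(x), ref(x)))  # True
```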

Thank you so much.
I use the following procedure, and it seems to work:

for name, layer in model.named_modules():
    if isinstance(layer, nn.Conv2d) and layer.bias is None:
        # add a tiny random bias on the same device/dtype as the weight
        layer.bias = nn.Parameter(
            torch.randn(layer.out_channels, device=layer.weight.device,
                        dtype=layer.weight.dtype) / 1e10)

but your way is cleaner.
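Here is a variant of the same procedure, wrapped in a hypothetical helper called add_conv_biases and run on a small stand-in model. It illustrates two practical points. First, the new bias is created on the weight's device and dtype. Second, the optimizer should be constructed after the biases are added, because an optimizer created earlier does not know about the new parameters.

```python
import torch
import torch.nn as nn


def add_conv_biases(model, scale=1e-10):
    """Give every bias-free Conv2d a tiny random bias; return the names of the changed layers."""
    added = []
    for name, layer in model.named_modules():
        if isinstance(layer, nn.Conv2d) and layer.bias is None:
            layer.bias = nn.Parameter(
                torch.randn(layer.out_channels, device=layer.weight.device,
                            dtype=layer.weight.dtype) * scale)
            added.append(name)
    return added


# Hypothetical small model standing in for the pretrained network
model = nn.Sequential(nn.Conv2d(3, 8, 3, bias=False), nn.ReLU(), nn.Conv2d(8, 4, 3, bias=False))
added = add_conv_biases(model)
print(added)  # ['0', '2']

# Create the optimizer after adding the biases, otherwise they will not be trained
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
print(sorted(model.state_dict().keys()))  # ['0.bias', '0.weight', '2.bias', '2.weight']
```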