How to split a model, add a new layer and combine all the models

In one of my use cases, I need to split trained models and add a custom layer in between to perform some calculations.

I have tried as follows

vgg_model = models.vgg11(pretrained=True)

class CustomLayer(nn.Module):
    def __init__(self): 
        super().__init__()

        
    def forward(self, input_features):
        input_features = input_features*0.5 # some calculations
        return input_features  


class NetworkWithCustomeLayer(nn.Module):
    def __init__(self, trained_model, layer_number):
        super(NetworkWithCustomeLayer, self).__init__()
        self.initial_layers =  nn.Sequential(*list(trained_model.children())[:layer_number])
        self.custom_layer = CustomLayer()
        self.last_layers =  nn.Sequential(*list(trained_model.children())[layer_number:])
    def forward(self, input):
        print("input shape",input.shape)
        x_1 = self.initial_layers(input)
        print("x1 shape",x_1.shape)
        x_c = self.custom_layer(x_1)
        print("xc shape",x_c.shape)
        x_2 = self.last_layers(x_c)
        print("x2 shape",x_2.shape)
        return x_2
modified_vgg = NetworkWithCustomeLayer(vgg_model, 1)
modified_vgg.to(device)

this is how modified_vgg looks

NetworkWithCustomeLayer(
  (initial_layers): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): ReLU(inplace=True)
      (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU(inplace=True)
      (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): ReLU(inplace=True)
      (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (12): ReLU(inplace=True)
      (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (14): ReLU(inplace=True)
      (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (16): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (17): ReLU(inplace=True)
      (18): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (19): ReLU(inplace=True)
      (20): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
  )
  (custom_layer): CustomLayer()
  (last_layers): Sequential(
    (0): AdaptiveAvgPool2d(output_size=(7, 7))
    (1): Sequential(
      (0): Linear(in_features=25088, out_features=4096, bias=True)
      (1): ReLU(inplace=True)
      (2): Dropout(p=0.5, inplace=False)
      (3): Linear(in_features=4096, out_features=4096, bias=True)
      (4): ReLU(inplace=True)
      (5): Dropout(p=0.5, inplace=False)
      (6): Linear(in_features=4096, out_features=1000, bias=True)
    )
  )
)

but when evaluting the model, I am getting following error:
RuntimeError: mat1 and mat2 shapes cannot be multiplied (114688x7 and 25088x4096)

below is the stack trace of the error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_24462/3433225722.py in <module>
     13     print('Accuracy of the network on test images: %d %%' % (
     14         100 * correct / total))
---> 15 evaluate_model(modified_vgg)

/tmp/ipykernel_24462/3433225722.py in evaluate_model(model)
      6             images = images.to(device)
      7             labels = labels.to(device)
----> 8             output=  model(images)
      9             _, predicted = torch.max(output, 1)
     10             total += labels.size(0)

/anaconda/envs/py38_pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/tmp/ipykernel_24462/645674472.py in forward(self, input)
     23         x_c = self.custom_layer(x_1)
     24         print("xc shape",x_c.shape)
---> 25         x_2 = self.last_layers(x_c)
     26         print("x2 shape",x_2.shape)
     27         return x_2

/anaconda/envs/py38_pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/anaconda/envs/py38_pytorch/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input)
    137     def forward(self, input):
    138         for module in self:
--> 139             input = module(input)
    140         return input
    141 

/anaconda/envs/py38_pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/anaconda/envs/py38_pytorch/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input)
    137     def forward(self, input):
    138         for module in self:
--> 139             input = module(input)
    140         return input
    141 

/anaconda/envs/py38_pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/anaconda/envs/py38_pytorch/lib/python3.8/site-packages/torch/nn/modules/linear.py in forward(self, input)
     94 
     95     def forward(self, input: Tensor) -> Tensor:
---> 96         return F.linear(input, self.weight, self.bias)
     97 
     98     def extra_repr(self) -> str:

/anaconda/envs/py38_pytorch/lib/python3.8/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1845     if has_torch_function_variadic(input, weight):
   1846         return handle_torch_function(linear, (input, weight), input, weight, bias=bias)
-> 1847     return torch._C._nn.linear(input, weight, bias)
   1848 
   1849 

RuntimeError: mat1 and mat2 shapes cannot be multiplied (114688x7 and 25088x4096)

when we see the output of the custom layer its same as the inital_layers, but getting shape mismatch error for final layers. dont know why the data shape is getting changed

Rewrapping the modules in an nn.Sequential block can easily break, since you would miss all functional API calls from the original forward method and will thus only work if the layers are initialized and executed sequentially.
For VGG11 you would be missing the torch.flatten operation from here, which would create the shape mismatch.
The safest way would be to create a custom model, add the custom layer, and override the forward.
Alternatively, you could also replace an already used layer (e.g. the last nn.MaxPool2d layer) with an nn.Sequential container containing itself and the new custom layer.

1 Like

Thanks @ptrblck.
separating fc layer and adding flatten before it helped.

class CustomLayer(nn.Module):
    def __init__(self): 
        super().__init__()

        
    def forward(self, input_features):
        input_features = input_features*0.5 # some calculations
        return input_features  


class NetworkWithCustomeLayer(nn.Module):
    def __init__(self, trained_model, layer_number):
        super(NetworkWithCustomeLayer, self).__init__()
        self.initial_layers =  nn.Sequential(*list(trained_model.children())[:layer_number])
        self.custom_layer = CustomLayer()
        self.last_layers =  nn.Sequential(*list(trained_model.children())[layer_number:-1])
        self.fc = list(vgg_model.children())[-1]
    def forward(self, input):
        print("input shape",input.shape)
        x_1 = self.initial_layers(input)
        print("x1 shape",x_1.shape)
        x_c = self.custom_layer(x_1)
        print("xc shape",x_c.shape)
        x_2 = self.last_layers(x_c)
        print("x2 shape",x_2.shape)
        x_3 = torch.flatten(x_2, 1)
        print("x3 shape",x_3.shape)
        x_4 = self.fc(x_3) 
        return x_4
modified_vgg = NetworkWithCustomeLayer(vgg_model, 1)
modified_vgg.to(device)