Delete the last layers of a pretrained model and connect it to an MLP

Hello, I’m trying to delete the last 3 layers of a pretrained model, average pool the output of the new last layer (i.e., the 4th-from-last layer of the original model), and then connect this to an MLP. So I loaded the pretrained weights and froze them, then deleted the last layers and appended an MLP (essentially, I want only the MLP to be trained), but I keep getting an error.
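For reference, the "average pool, then MLP" step could look roughly like the sketch below (a minimal sketch, assuming global average pooling; the 256 channels come from the structure that follows, so the first Linear layer of the MLP head would then need in_features=256):

import torch
import torch.nn as nn

# Hypothetical bridge between the truncated model and the MLP head:
# collapse each (H, W) feature map to one value, then flatten so a
# Linear layer can consume the result.
pool = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),  # (N, 256, H, W) -> (N, 256, 1, 1)
    nn.Flatten(),             # (N, 256, 1, 1) -> (N, 256)
)

features = torch.randn(8, 256, 32, 32)  # dummy output of the truncated model
print(pool(features).shape)             # torch.Size([8, 256])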
Structure of the pretrained model (prior to deleting the last 3 layers):

cLAPIRN(
  (transform): SpatialTransform_unit()
  (input_encoder_lvl1): Sequential(
    (0): Conv2d(2, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
  (down_conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (resblock_group_lvl1): ModuleList(
    (0): PreActBlock_Conditional(
      (ai1): ConditionalInstanceNorm(
        (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (style): Linear(in_features=64, out_features=512, bias=True)
      )
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (ai2): ConditionalInstanceNorm(
        (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (style): Linear(in_features=64, out_features=512, bias=True)
      )
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (mapping): Sequential(
        (0): Linear(in_features=1, out_features=64, bias=True)
        (1): LeakyReLU(negative_slope=0.2)
        (2): Linear(in_features=64, out_features=64, bias=True)
        (3): LeakyReLU(negative_slope=0.2)
        (4): Linear(in_features=64, out_features=64, bias=True)
        (5): LeakyReLU(negative_slope=0.2)
        (6): Linear(in_features=64, out_features=64, bias=True)
        (7): LeakyReLU(negative_slope=0.2)
      )
    )
    (1): LeakyReLU(negative_slope=0.2)
    (2): PreActBlock_Conditional(
      (ai1): ConditionalInstanceNorm(
        (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (style): Linear(in_features=64, out_features=512, bias=True)
      )
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (ai2): ConditionalInstanceNorm(
        (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (style): Linear(in_features=64, out_features=512, bias=True)
      )
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (mapping): Sequential(
        (0): Linear(in_features=1, out_features=64, bias=True)
        (1): LeakyReLU(negative_slope=0.2)
        (2): Linear(in_features=64, out_features=64, bias=True)
        (3): LeakyReLU(negative_slope=0.2)
        (4): Linear(in_features=64, out_features=64, bias=True)
        (5): LeakyReLU(negative_slope=0.2)
        (6): Linear(in_features=64, out_features=64, bias=True)
        (7): LeakyReLU(negative_slope=0.2)
      )
    )
    (3): LeakyReLU(negative_slope=0.2)
    (4): PreActBlock_Conditional(
      (ai1): ConditionalInstanceNorm(
        (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (style): Linear(in_features=64, out_features=512, bias=True)
      )
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (ai2): ConditionalInstanceNorm(
        (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (style): Linear(in_features=64, out_features=512, bias=True)
      )
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (mapping): Sequential(
        (0): Linear(in_features=1, out_features=64, bias=True)
        (1): LeakyReLU(negative_slope=0.2)
        (2): Linear(in_features=64, out_features=64, bias=True)
        (3): LeakyReLU(negative_slope=0.2)
        (4): Linear(in_features=64, out_features=64, bias=True)
        (5): LeakyReLU(negative_slope=0.2)
        (6): Linear(in_features=64, out_features=64, bias=True)
        (7): LeakyReLU(negative_slope=0.2)
      )
    )
    (5): LeakyReLU(negative_slope=0.2)
    (6): PreActBlock_Conditional(
      (ai1): ConditionalInstanceNorm(
        (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (style): Linear(in_features=64, out_features=512, bias=True)
      )
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (ai2): ConditionalInstanceNorm(
        (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (style): Linear(in_features=64, out_features=512, bias=True)
      )
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (mapping): Sequential(
        (0): Linear(in_features=1, out_features=64, bias=True)
        (1): LeakyReLU(negative_slope=0.2)
        (2): Linear(in_features=64, out_features=64, bias=True)
        (3): LeakyReLU(negative_slope=0.2)
        (4): Linear(in_features=64, out_features=64, bias=True)
        (5): LeakyReLU(negative_slope=0.2)
        (6): Linear(in_features=64, out_features=64, bias=True)
        (7): LeakyReLU(negative_slope=0.2)
      )
    )
    (7): LeakyReLU(negative_slope=0.2)
    (8): PreActBlock_Conditional(
      (ai1): ConditionalInstanceNorm(
        (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (style): Linear(in_features=64, out_features=512, bias=True)
      )
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (ai2): ConditionalInstanceNorm(
        (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (style): Linear(in_features=64, out_features=512, bias=True)
      )
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (mapping): Sequential(
        (0): Linear(in_features=1, out_features=64, bias=True)
        (1): LeakyReLU(negative_slope=0.2)
        (2): Linear(in_features=64, out_features=64, bias=True)
        (3): LeakyReLU(negative_slope=0.2)
        (4): Linear(in_features=64, out_features=64, bias=True)
        (5): LeakyReLU(negative_slope=0.2)
        (6): Linear(in_features=64, out_features=64, bias=True)
        (7): LeakyReLU(negative_slope=0.2)
      )
    )
    (9): LeakyReLU(negative_slope=0.2)
  )
  (up): ConvTranspose2d(256, 256, kernel_size=(2, 2), stride=(2, 2), bias=False)
  (down_avg): AvgPool2d(kernel_size=3, stride=2, padding=1)
  (output_lvl1): Sequential(
    (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Conv2d(256, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (3): Softsign()
  )
)

Structure of the concatenated model:

MLP(
  (pretrained): Sequential(
    (0): SpatialTransform_unit()
    (1): Sequential(
      (0): Conv2d(2, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): LeakyReLU(negative_slope=0.2)
      (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (3): ModuleList(
      (0): PreActBlock_Conditional(
        (ai1): ConditionalInstanceNorm(
          (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (style): Linear(in_features=64, out_features=512, bias=True)
        )
        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (ai2): ConditionalInstanceNorm(
          (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (style): Linear(in_features=64, out_features=512, bias=True)
        )
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (mapping): Sequential(
          (0): Linear(in_features=1, out_features=64, bias=True)
          (1): LeakyReLU(negative_slope=0.2)
          (2): Linear(in_features=64, out_features=64, bias=True)
          (3): LeakyReLU(negative_slope=0.2)
          (4): Linear(in_features=64, out_features=64, bias=True)
          (5): LeakyReLU(negative_slope=0.2)
          (6): Linear(in_features=64, out_features=64, bias=True)
          (7): LeakyReLU(negative_slope=0.2)
        )
      )
      (1): LeakyReLU(negative_slope=0.2)
      (2): PreActBlock_Conditional(
        (ai1): ConditionalInstanceNorm(
          (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (style): Linear(in_features=64, out_features=512, bias=True)
        )
        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (ai2): ConditionalInstanceNorm(
          (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (style): Linear(in_features=64, out_features=512, bias=True)
        )
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (mapping): Sequential(
          (0): Linear(in_features=1, out_features=64, bias=True)
          (1): LeakyReLU(negative_slope=0.2)
          (2): Linear(in_features=64, out_features=64, bias=True)
          (3): LeakyReLU(negative_slope=0.2)
          (4): Linear(in_features=64, out_features=64, bias=True)
          (5): LeakyReLU(negative_slope=0.2)
          (6): Linear(in_features=64, out_features=64, bias=True)
          (7): LeakyReLU(negative_slope=0.2)
        )
      )
      (3): LeakyReLU(negative_slope=0.2)
      (4): PreActBlock_Conditional(
        (ai1): ConditionalInstanceNorm(
          (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (style): Linear(in_features=64, out_features=512, bias=True)
        )
        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (ai2): ConditionalInstanceNorm(
          (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (style): Linear(in_features=64, out_features=512, bias=True)
        )
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (mapping): Sequential(
          (0): Linear(in_features=1, out_features=64, bias=True)
          (1): LeakyReLU(negative_slope=0.2)
          (2): Linear(in_features=64, out_features=64, bias=True)
          (3): LeakyReLU(negative_slope=0.2)
          (4): Linear(in_features=64, out_features=64, bias=True)
          (5): LeakyReLU(negative_slope=0.2)
          (6): Linear(in_features=64, out_features=64, bias=True)
          (7): LeakyReLU(negative_slope=0.2)
        )
      )
      (5): LeakyReLU(negative_slope=0.2)
      (6): PreActBlock_Conditional(
        (ai1): ConditionalInstanceNorm(
          (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (style): Linear(in_features=64, out_features=512, bias=True)
        )
        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (ai2): ConditionalInstanceNorm(
          (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (style): Linear(in_features=64, out_features=512, bias=True)
        )
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (mapping): Sequential(
          (0): Linear(in_features=1, out_features=64, bias=True)
          (1): LeakyReLU(negative_slope=0.2)
          (2): Linear(in_features=64, out_features=64, bias=True)
          (3): LeakyReLU(negative_slope=0.2)
          (4): Linear(in_features=64, out_features=64, bias=True)
          (5): LeakyReLU(negative_slope=0.2)
          (6): Linear(in_features=64, out_features=64, bias=True)
          (7): LeakyReLU(negative_slope=0.2)
        )
      )
      (7): LeakyReLU(negative_slope=0.2)
      (8): PreActBlock_Conditional(
        (ai1): ConditionalInstanceNorm(
          (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (style): Linear(in_features=64, out_features=512, bias=True)
        )
        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (ai2): ConditionalInstanceNorm(
          (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (style): Linear(in_features=64, out_features=512, bias=True)
        )
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (mapping): Sequential(
          (0): Linear(in_features=1, out_features=64, bias=True)
          (1): LeakyReLU(negative_slope=0.2)
          (2): Linear(in_features=64, out_features=64, bias=True)
          (3): LeakyReLU(negative_slope=0.2)
          (4): Linear(in_features=64, out_features=64, bias=True)
          (5): LeakyReLU(negative_slope=0.2)
          (6): Linear(in_features=64, out_features=64, bias=True)
          (7): LeakyReLU(negative_slope=0.2)
        )
      )
      (9): LeakyReLU(negative_slope=0.2)
    )
  )
  (blocks): Sequential(
    (0): Sequential(
      (0): Linear(in_features=30, out_features=64, bias=True)
      (1): ReLU()
    )
    (1): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
    )
    (2): Sequential(
      (0): Linear(in_features=64, out_features=32, bias=True)
      (1): ReLU()
    )
    (3): Sequential(
      (0): Linear(in_features=32, out_features=8, bias=True)
      (1): ReLU()
    )
  )
  (out_layer): Sequential(
    (0): Linear(in_features=16, out_features=1, bias=True)
    (1): ReLU()
  )
)


This is how I concatenated them:

import torch

model_lvl1 = cLAPIRN(2, 2, start_channel, is_train=True, imgshape=imgshape_4, range_flow=0.4)

model_path = "path"
model_lvl1.load_state_dict(torch.load(model_path))

# remove the last 3 layers
pretrained = torch.nn.Sequential(*list(model_lvl1.children())[:-3])

# freeze the pretrained weights
for params in pretrained.parameters():
    params.requires_grad = False

# create the MLP model, passing the truncated pretrained model to its __init__
mlp_model = MLP(pretrained)
print(mlp_model)
output = mlp_model(X, Y, reg_code)
print(output)
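As a sanity check on the [:-3] slice, it may help to print which top-level children are kept and which are dropped; children() follows registration order, which matches the printed structure above (a minimal sketch):

kept = list(model_lvl1.children())[:-3]
dropped = list(model_lvl1.children())[-3:]
print([type(m).__name__ for m in kept])
print([type(m).__name__ for m in dropped])  # ConvTranspose2d, AvgPool2d, Sequential
                                            # i.e. up, down_avg, output_lvl1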

This is the MLP class:

import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, pretrained):
        super(MLP, self).__init__()
        self.pretrained = pretrained
        self.layers = [30, 64, 64, 32, 8]
        blocks = [fc_builder(input_dims, output_dims)
                  for input_dims, output_dims in zip(self.layers, self.layers[1:])]
        self.blocks = nn.Sequential(*blocks)
        self.out_layer = nn.Sequential(
            nn.Linear(16, 1),  # check against the prediction output: the last block emits 8 features, not 16
            nn.ReLU()
        )

    def forward(self, x, y, reg_code):
        x = self.pretrained(x, y, reg_code)
        x = self.blocks(x)
        output = self.out_layer(x)
        return output


def fc_builder(input_dims, output_dims):
    return nn.Sequential(
        nn.Linear(input_dims, output_dims),
        nn.ReLU()
    )

This is the error I’m getting:

 return forward_call(*input, **kwargs)
TypeError: Sequential.forward() takes 2 positional arguments but 4 were given

And I don’t have any Sequential.forward() method in my own code.
Is it possible to delete the last 3 layers directly from the first model, without having to pass it through nn.Sequential, and just attach the MLP to it?

Does your pretrained model take 3 inputs?
Comparing the error message with the code, the pretrained model (now wrapped in nn.Sequential) not accepting 3 inputs is the probable cause: nn.Sequential.forward(self, input) takes a single tensor, so calling it with three is reported as "takes 2 positional arguments but 4 were given" (self plus your three tensors).
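The mismatch is easy to reproduce in isolation; nn.Sequential forwards exactly one input, regardless of how many arguments the original model's forward accepted (a minimal sketch):

import torch
import torch.nn as nn

seq = nn.Sequential(nn.Conv2d(2, 4, kernel_size=3, padding=1))
x = torch.randn(1, 1, 8, 8)
y = torch.randn(1, 1, 8, 8)
reg_code = torch.randn(1, 1)

# seq(x, y, reg_code)  # TypeError: Sequential.forward() takes 2 positional
#                      # arguments but 4 were given (self + the 3 tensors)
out = seq(torch.cat((x, y), 1))  # Sequential only ever forwards one input
print(out.shape)                 # torch.Size([1, 4, 8, 8])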

Yes, it does. This is the forward method of my self.pretrained:

def forward(self, x, y, reg_code):
    input = torch.cat((x, y), 1)
    input = self.avg(input)
    input = self.down(input)
    # ... (rest of forward omitted)

I just did this to delete the last 2 layers:

for i in range(2):
    model_lvl1._modules.popitem()  # pops the most recently registered module
                                   # (output_lvl1 first, then down_avg)
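A hedged alternative that avoids mutating the private _modules dict and keeps the model's original 3-input forward working: replace the unwanted trailing layers with nn.Identity. Whether this yields the feature map you want still depends on how cLAPIRN.forward combines those layers internally, so treat it as a sketch, not a drop-in fix:

import torch.nn as nn

# Assumption: attribute names follow the printed structure at the top;
# up, down_avg and output_lvl1 are the last 3 registered submodules.
model_lvl1.up = nn.Identity()           # no-op instead of ConvTranspose2d
model_lvl1.down_avg = nn.Identity()     # no-op instead of AvgPool2d
model_lvl1.output_lvl1 = nn.Identity()  # no-op instead of the output head

# model_lvl1 still accepts (x, y, reg_code), unlike an nn.Sequential rebuild
features = model_lvl1(X, Y, reg_code)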