forward() argument error

I want to use ConvNeXt as my backbone and make some changes to it, so I use nn.Sequential to split the model into several parts.

from timm.models import create_model
import timm

import torch
import torch.nn as nn
import torchvision.utils

import imp

from networks.convnext import *
import networks.convnext as convnext

if __name__ == "__main__":

    model_name = 'convnext'
    num_classes = 1024

    weights_path = './pretrained_weights/convnext/convnext_large_22k_1k_224.pth'

    net_type = 'convnext_large'

    # move the model to the GPU as well, so it lives on the same device as the input
    basenet = convnext.convnext_large(pretrained=False, num_classes=num_classes).cuda()

    x = torch.ones(1, 3, 224, 224).cuda()

    conv2 = nn.Sequential(*list(basenet.children())[:-3])

    x = conv2(x)

    conv3 = nn.Sequential(*list(basenet.children())[-3][:-2])
    conv4 = nn.Sequential(*list(basenet.children())[-3][-2])
    conv5 = nn.Sequential(*list(basenet.children())[-3][-1])
    layer_norm = list(basenet.children())[-2]
    head = list(basenet.children())[-1]

However, when I run this script, it fails with the following error:

Traceback (most recent call last):
  File "test_load_model.py", line 28, in <module>
    x = conv2(x)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/container.py", line 119, in forward
    input = module(input)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 1 positional argument but 2 were given

The ConvNeXt code is taken from the official implementation:
https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py

Does anyone know how to fix this error? Any help is appreciated!

I cannot reproduce this exact error; instead I get a different one:

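import torch
import torch.nn as nn
from models.convnext import ConvNeXt  # models/convnext.py from the official repo

basenet = ConvNeXt()
conv2 = nn.Sequential(*list(basenet.children())[:-3])
x = torch.randn(2, 3, 224, 224)

x = conv2(x)
# > NotImplementedError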

This happens because conv2 wraps an nn.ModuleList instead of nn.Sequential containers, and nn.ModuleList does not implement forward(), so it cannot be called on an input:

Sequential(
  (0): ModuleList(
    (0): Sequential(
      (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm()
    )
    (1): Sequential(
      (0): LayerNorm()
      (1): Conv2d(96, 192, kernel_size=(2, 2), stride=(2, 2))
    )
    (2): Sequential(
      (0): LayerNorm()
      (1): Conv2d(192, 384, kernel_size=(2, 2), stride=(2, 2))
    )
    (3): Sequential(
      (0): LayerNorm()
      (1): Conv2d(384, 768, kernel_size=(2, 2), stride=(2, 2))
    )
  )
)
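A minimal sketch of the difference, using two plain Conv2d layers with arbitrary sizes as stand-ins for ConvNeXt's downsampling layers: calling an nn.ModuleList raises NotImplementedError, while an nn.Sequential built from the same modules chains them.

```python
import torch
import torch.nn as nn

# Arbitrary stand-ins for ConvNeXt's downsampling layers (sizes chosen for illustration)
blocks = [nn.Conv2d(3, 8, kernel_size=4, stride=4), nn.Conv2d(8, 16, kernel_size=2, stride=2)]

as_list = nn.ModuleList(blocks)   # only a container; defines no forward()
as_seq = nn.Sequential(*blocks)   # passes each output into the next module

x = torch.randn(2, 3, 32, 32)

try:
    as_list(x)
    list_callable = True
except NotImplementedError:
    list_callable = False

y = as_seq(x)
print(list_callable, tuple(y.shape))  # False (2, 16, 4, 4)
```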

You can iterate over the nn.ModuleList to get all modules and call them directly, e.g. as seen here.
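A sketch of one way to split the backbone, assuming the attribute layout of the official implementation (`downsample_layers`, `stages`, `norm`, `head`), whose forward pass alternates `downsample_layers[i]` and `stages[i]`. `ToyConvNeXt` below is a hypothetical stand-in with simplified layers so the example runs on its own; with the real model you would pair `basenet.downsample_layers[i]` with `basenet.stages[i]` in the same way.

```python
import torch
import torch.nn as nn


class ToyConvNeXt(nn.Module):
    """Hypothetical stand-in mimicking ConvNeXt's attribute layout with simple layers."""

    def __init__(self, dims=(8, 16, 32, 64), num_classes=10):
        super().__init__()
        self.downsample_layers = nn.ModuleList(
            [nn.Conv2d(3, dims[0], kernel_size=4, stride=4)]
            + [nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2) for i in range(3)]
        )
        self.stages = nn.ModuleList(
            [nn.Sequential(nn.Conv2d(d, d, kernel_size=3, padding=1), nn.GELU()) for d in dims]
        )
        self.norm = nn.LayerNorm(dims[-1])
        self.head = nn.Linear(dims[-1], num_classes)

    def forward(self, x):
        # Downsampling and stages are interleaved, so they cannot be sliced apart naively
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        return self.head(self.norm(x.mean([-2, -1])))  # global average pooling first


model = ToyConvNeXt().eval()

# One callable part per stage; the modules (and weights) are shared with `model`
conv2, conv3, conv4, conv5 = [
    nn.Sequential(model.downsample_layers[i], model.stages[i]) for i in range(4)
]

x = torch.randn(2, 3, 64, 64)
with torch.no_grad():
    features = []
    out = x
    for part in (conv2, conv3, conv4, conv5):
        out = part(out)
        features.append(out)
    # norm and head expect pooled (N, C) features, not a feature map
    logits = model.head(model.norm(out.mean([-2, -1])))
    matches = torch.allclose(logits, model(x))

print([tuple(f.shape) for f in features], matches)
```

Because each part is an nn.Sequential of modules that are already registered on the model, loading pretrained weights into the full model also updates the parts.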