Slicing sequential modules

Hi,

I am using a pretrained ResNet-18 and would like to truncate the model so that it keeps only the first two convolutions. I tried the following code, but I found that the model contains a module called BasicBlock, which I don't know how to slice. Any ideas, please?

model = torchvision.models.resnet18(pretrained=True)

truncated_model = nn.Sequential(*list(model.children())[:1])

The truncated model will contain only the first convolution, i.e.
(0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

truncated_model = nn.Sequential(*list(model.children())[:2])

(0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)

Continuing this way:

truncated_model = nn.Sequential(*list(model.children())[:5])

Sequential (
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
  (2): ReLU (inplace)
  (3): MaxPool2d (size=(3, 3), stride=(2, 2), padding=(1, 1), dilation=(1, 1))
  (4): Sequential (
    (0): BasicBlock (
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
      (relu): ReLU (inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    )
    (1): BasicBlock (
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
      (relu): ReLU (inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    )
  )
)

I would like to truncate my model right after the first conv in the first BasicBlock. Any ideas?

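Because the 5th module in your truncated Sequential is itself a Sequential (made of BasicBlocks), slicing the top-level children is not enough: you need to reach inside that inner module and cut it as well.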

How does one do that in PyTorch?