ResNet pre-trained model and modification

Hello everyone,
I am using a pretrained ResNet-34 model for feature extraction. I added some layers and modified some blocks. During training, how can I make good use of the pretrained ResNet weights and train only the weights of the modified parts? Also, if the input is X = torch.rand(size=(1, 3, 224, 224)), how do I feed it into the model for training? Thank you.

Here is the code:

import torchvision.models as models
import torch
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo

def conv3x3(in_planes, out_planes, stride=1):
    """Return a 3x3 ``nn.Conv2d`` with padding 1 and no bias.

    With padding 1 the spatial size is preserved for ``stride=1`` and halved
    (rounded up) for ``stride=2``. The bias is omitted because every use in this
    file is followed by a BatchNorm layer.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)
class BasicBlock(nn.Module):
    """Two 3x3-conv residual block (the building block of ResNet-18/34).

    Submodule names (``conv1``, ``bn1``, ``conv2``, ``bn2``, ``downsample``)
    deliberately mirror torchvision's ResNet so pretrained ``state_dict`` keys
    can be copied across by name.

    Args:
        inplanes: channels of the incoming tensor.
        planes: channels produced by both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input so the shortcut has
            the same shape as the main path; ``None`` means identity shortcut.
    """

    # Output channels of this block are ``planes * expansion``.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Project the shortcut when stride/channels change so the sum is valid.
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class CNN(nn.Module):
    """ResNet-style backbone whose original ``fc`` head is replaced by ``fclass``.

    The stem (``conv1``/``bn1``) and ``layer1``..``layer4`` use the same names
    and layout as torchvision's ResNet, so a pretrained torchvision
    ``state_dict`` can be copied into them by key.

    Args:
        block: residual block class. It must expose an ``expansion`` class
            attribute and accept ``(inplanes, planes, stride, downsample)``.
        layers: number of blocks in each of the four stages,
            e.g. ``[3, 4, 6, 3]`` for ResNet-34.
        num_classes: output width of the ``fclass`` linear layer.
    """

    def __init__(self, block, layers, num_classes=512):
        super(CNN, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

        # Channel count coming out of layer4 (512 for BasicBlock). The original
        # hard-coded 2048, which only matches Bottleneck-based ResNets.
        feat_channels = 512 * block.expansion
        # Newly added transposed-convolution layer. With kernel 3, stride 1 and
        # padding 1 it keeps the spatial size.
        self.convtranspose1 = nn.ConvTranspose2d(
            feat_channels, feat_channels, kernel_size=3, stride=1, padding=1,
            output_padding=0, groups=1, bias=False, dilation=1)
        # Newly added max-pooling layer (kernel 3, stride 1, padding 1 keeps
        # the spatial size).
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        # Replaces the original ``fc`` layer with a new head named ``fclass``.
        self.fclass = nn.Linear(feat_channels, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-style init: std = sqrt(2 / (k_h * k_w * out_channels)).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, 0.0, math.sqrt(2.0 / n))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one ResNet stage of ``blocks`` residual blocks.

        Only the first block may change stride/channels. It gets a 1x1-conv +
        BatchNorm shortcut when its input shape differs from its output shape.
        Updates ``self.inplanes`` to this stage's output channel count.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Map an image batch to ``(N, num_classes)`` features via ``fclass``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # NOTE: ``convtranspose1`` / ``maxpool2`` are constructed but not used.
        # They need a 4-D (N, C, H, W) tensor, so if enabled they belong here,
        # after layer4 and before avgpool, not after the flatten below.
        x = self.avgpool(x)
        x = x.flatten(1)
        x = self.fclass(x)

        return x


# ImageNet-pretrained ResNet-34. ``pretrained=True`` is the torchvision < 0.13
# API; newer releases prefer ``weights=models.ResNet34_Weights.DEFAULT``.
resnet34 = models.resnet34(pretrained=True)
# [3, 4, 6, 3] is the number of BasicBlocks in layer1..layer4
# (ResNet-18 would be [2, 2, 2, 2]).
cnn = CNN(BasicBlock, [3, 4, 6, 3])

pretrained_dict = resnet34.state_dict()
model_dict = cnn.state_dict()

# Keep only tensors whose name AND shape match. This drops resnet34's ``fc.*``
# (replaced by ``fclass``) and leaves the newly added layers at their own init.
pretrained_dict = {
    k: v for k, v in pretrained_dict.items()
    if k in model_dict and v.shape == model_dict[k].shape
}
# Actually copy the pretrained tensors into the model. Filtering alone does
# nothing; without this the network keeps its random initialization.
model_dict.update(pretrained_dict)
cnn.load_state_dict(model_dict)

# Freeze every parameter that came from the pretrained checkpoint so only the
# new/modified layers (e.g. ``fclass``, ``convtranspose1``) are trained.
# BatchNorm running statistics are buffers, not parameters: they still update
# in train mode unless those BN modules are put in ``eval()``.
for name, param in cnn.named_parameters():
    param.requires_grad = name not in pretrained_dict

# Give the optimizer only the trainable parameters. In a training loop:
# optimizer.zero_grad(); loss = criterion(cnn(X), target); loss.backward();
# optimizer.step()
trainable_params = [p for p in cnn.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params, lr=1e-3, momentum=0.9)

X = torch.rand(size=(1, 3, 224, 224))
out = cnn(X)
print('out shape', out.shape)

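Hi, you can freeze the ResNet weights (set `requires_grad = False` on them) so that only your new layers are updated during training. Also note that filtering `pretrained_dict` is not enough on its own: merge it into `model_dict` and call `cnn.load_state_dict(model_dict)`, otherwise the pretrained weights are never used. Take a look at this example: Transfer Learning for Computer Vision Tutorial — PyTorch Tutorials 1.11.0+cu102 documentation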