How to override a pretrained model class in PyTorch?

Hello there, I am new to PyTorch. I am trying to use the pretrained ResNet (2+1)D model [1], but its first layer expects 3 input channels and my data has only one channel, so I suppose I have to override that class. Please have a look at my attempt; I am getting this error:

TypeError: _video_resnet() got multiple values for keyword argument 'stem'
```python
import torch.nn as nn
import torchvision


class R2Plus1dStem4IMAGES(nn.Sequential):
    """R(2+1)D stem for 1-channel input: a spatial (1x7x7) convolution
    followed by a temporal (3x1x1) convolution, as in torchvision's R2Plus1dStem.
    """
    def __init__(self):
        super(R2Plus1dStem4IMAGES, self).__init__(
            nn.Conv3d(1, 45, kernel_size=(1, 7, 7),
                      stride=(1, 2, 2), padding=(0, 3, 3),
                      bias=False),
            nn.BatchNorm3d(45),
            nn.ReLU(inplace=True),
            nn.Conv3d(45, 64, kernel_size=(3, 1, 1),
                      stride=(1, 1, 1), padding=(1, 0, 0),
                      bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True))


# This call raises the TypeError above
model = torchvision.models.video.r2plus1d_18(pretrained=True, stem=R2Plus1dStem4IMAGES)

model.fc = nn.Linear(model.fc.in_features, 3)
```

[1] https://pytorch.org/docs/stable/_modules/torchvision/models/video/resnet.html#r2plus1d_18

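Glancing at the code in torchvision, it looks like r2plus1d_18 doesn't support overriding the stem: it already passes stem=R2Plus1dStem to the underlying _video_resnet builder and then forwards your **kwargs, so your stem argument arrives a second time. You should probably call _video_resnet directly, with the same arguments as specified in r2plus1d_18 but your own stem. Keep in mind that with pretrained=True the checkpoint's first convolution expects 3 input channels, so loading it into a 1-channel stem will likely fail with a size-mismatch error.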
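The collision can be reproduced in plain Python. This is a simplified, hypothetical stand-in, not torchvision's actual code (the names `_build` and `public_factory` are made up): a public factory that already passes `stem=` to its builder and also forwards `**kwargs` raises the same kind of `TypeError` when the caller supplies `stem` as well.

```python
def _build(arch, stem=None, **kwargs):
    """Hypothetical stand-in for a private builder such as _video_resnet."""
    return arch, stem, kwargs


def public_factory(**kwargs):
    # The public factory fixes `stem` itself and forwards everything else,
    # mirroring (in simplified form) how r2plus1d_18 calls its builder.
    return _build("r2plus1d_18", stem="R2Plus1dStem", **kwargs)


def reproduce_stem_error():
    """Pass `stem` to the factory and return the resulting error message."""
    try:
        public_factory(stem="R2Plus1dStem4IMAGES")
    except TypeError as err:
        return str(err)
    return None


print(reproduce_stem_error())
```

The exact wording differs between Python versions (newer versions may prefix the function name with its module), but the "got multiple values for keyword argument 'stem'" part is the same.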

Hi Kahn, I guess you want to keep using the pretrained weights of the ResNet. The simplest option is to expand your input tensor by replicating its single channel three times:

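```python
import torch

# make up some input: (batch, channels, frames, height, width)
batch_size, num_frames, img_size = 2, 8, 112
input_tensor = torch.randn(batch_size, 1, num_frames, img_size, img_size)
# replicate the single channel 3 times (expand returns a view, no copy)
input_tensor = input_tensor.expand(-1, 3, -1, -1, -1)
# go on to use in resnet ...
```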
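The replication trick can also be folded into the network itself. Because the replicated channels are identical, applying the pretrained first convolution to them gives the same result as applying a single-channel convolution whose kernel is the sum of the three RGB kernels. The sketch below is an alternative to overriding the class, not part of the answers above. It assumes the model exposes `model.stem` as an `nn.Sequential` whose first element is the 3-channel `Conv3d` (as torchvision's `R2Plus1dStem` does), and the helper names `to_single_channel` and `adapt_stem_to_one_channel` are hypothetical.

```python
import torch
import torch.nn as nn


def to_single_channel(conv: nn.Conv3d) -> nn.Conv3d:
    """Return a 1-input-channel copy of `conv` whose kernel is the sum of
    the original kernels over the input-channel axis (hypothetical helper)."""
    assert conv.groups == 1, "grouped convolutions are not handled here"
    new_conv = nn.Conv3d(
        1, conv.out_channels,
        kernel_size=conv.kernel_size, stride=conv.stride,
        padding=conv.padding, dilation=conv.dilation,
        bias=conv.bias is not None, padding_mode=conv.padding_mode,
    )
    with torch.no_grad():
        # weight shape is (out, in, kT, kH, kW); summing over `in` matches
        # feeding the original conv a 1-channel input replicated `in` times
        new_conv.weight.copy_(conv.weight.sum(dim=1, keepdim=True))
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv


def adapt_stem_to_one_channel(model: nn.Module) -> nn.Module:
    """Swap the first stem conv in place (assumes model.stem[0] is a Conv3d)."""
    model.stem[0] = to_single_channel(model.stem[0])
    return model
```

With the question's model this would be used roughly as `model = torchvision.models.video.r2plus1d_18(pretrained=True)` followed by `adapt_stem_to_one_channel(model)`, before replacing `model.fc`; the model then takes inputs of shape (batch, 1, frames, height, width). Averaging the kernels instead of summing them is another option; it scales the first layer's output by 1/3.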