Hello there, I am new to PyTorch. I am trying to use the pretrained R(2+1)D ResNet [1], but since its first layer expects 3 input channels and my data has only one channel, I suppose I have to override the stem class. Please have a look at my attempt; I am getting this error:
TypeError: _video_resnet() got multiple values for keyword argument 'stem'
class R2Plus1dStem4IMAGES(nn.Sequential):
    """R(2+1)D stem for inputs with a configurable number of channels.

    Instead of one full 3D convolution, the stem applies a (1, 7, 7)
    convolution followed by a (3, 1, 1) convolution, each followed by
    batch normalisation and ReLU. With 5D input ``(N, C, D, H, W)`` the
    first convolution only mixes the last two (spatial) axes and halves
    them via its stride; the second only mixes the ``D`` axis (presumably
    time/frames in the video setting) and keeps every size unchanged.

    Args:
        in_channels: Number of channels ``C`` of the input. Defaults to 1
            (single-channel data), so the class can still be constructed
            with no arguments, which is how a ``stem`` factory is
            typically invoked.
    """

    def __init__(self, in_channels: int = 1) -> None:
        super().__init__(
            # Spatial part of the factorised convolution: 7x7 over (H, W).
            nn.Conv3d(
                in_channels,
                45,
                kernel_size=(1, 7, 7),
                stride=(1, 2, 2),
                padding=(0, 3, 3),
                bias=False,
            ),
            nn.BatchNorm3d(45),
            nn.ReLU(inplace=True),
            # Temporal part: kernel of 3 along D, size-preserving padding.
            nn.Conv3d(
                45,
                64,
                kernel_size=(3, 1, 1),
                stride=(1, 1, 1),
                padding=(1, 0, 0),
                bias=False,
            ),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
        )
# r2plus1d_18() already passes its own `stem=` to _video_resnet() and forwards
# **kwargs, so supplying `stem=` here raises
# "TypeError: _video_resnet() got multiple values for keyword argument 'stem'".
# Build the pretrained network first, then swap the stem afterwards.
# NOTE: newer torchvision releases prefer `weights=` over `pretrained=`.
model = torchvision.models.video.r2plus1d_18(pretrained=True)

# Carry the pretrained stem weights over to the single-channel stem. Every
# layer has the same shape except the first convolution, whose kernel has
# 3 input channels; summing it over that axis gives a 1-channel kernel that
# responds like the RGB kernel would to a grayscale clip copied into all
# three channels (the convolution has no bias).
stem_state = model.stem.state_dict()
stem_state["0.weight"] = stem_state["0.weight"].sum(dim=1, keepdim=True)
model.stem = R2Plus1dStem4IMAGES()
model.stem.load_state_dict(stem_state)

# Replace the classification head with a fresh 3-class linear layer.
model.fc = nn.Linear(model.fc.in_features, 3)
[1] https://pytorch.org/docs/stable/_modules/torchvision/models/video/resnet.html#r2plus1d_18