How to transfer the pretrained weights of a standard ResNet50 to a 4-channel input?

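You could replace the first conv layer with a new one that takes 4 input channels, copy the pretrained filters into its first three input channels, and initialize the fourth channel separately (here by reusing the filters of the first, red, channel):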

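```python
import torch
import torch.nn as nn
from torchvision import models

# ImageNet weights (torchvision >= 0.13; on older versions use pretrained=True)
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
weight = model.conv1.weight.clone()  # pretrained RGB filters, shape (64, 3, 7, 7)
# New first layer: 4 input channels, otherwise the same hyperparameters
model.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
with torch.no_grad():
    model.conv1.weight[:, :3] = weight  # reuse the pretrained RGB filters
    model.conv1.weight[:, 3] = model.conv1.weight[:, 0]  # 4th channel: copy of red

x = torch.randn(10, 4, 224, 224)
output = model(x)  # shape: (10, 1000)
```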