How to transfer the pretrained weights for a standard ResNet50 to a 4-channel input?

In remote sensing, images usually have more than three channels, for example NIR, R, G and B. I want to leverage the pretrained weights of a standard ResNet50 and transfer them to a 4-channel input version by copying the RGB weights and initializing the NIR weights as equivalent to the red channel. How can I do this?

You could replace the first conv layer with a new one using 4 input channels:

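```python
import torch
import torch.nn as nn
from torchvision import models

# torchvision >= 0.13; older versions use models.resnet50(pretrained=True)
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
weight = model.conv1.weight.clone()  # pretrained RGB filters, shape [64, 3, 7, 7]
model.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
with torch.no_grad():
    model.conv1.weight[:, :3] = weight                    # copy the R, G, B filters
    model.conv1.weight[:, 3] = model.conv1.weight[:, 0]   # 4th channel (NIR) reuses the red filter

x = torch.randn(10, 4, 224, 224)  # assumed channel order: R, G, B, NIR
output = model(x)
```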
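The same idea can be wrapped in a small reusable function. This is a sketch, not a torchvision API: `expand_conv_input` is a hypothetical helper name, and a freshly created 3-channel `nn.Conv2d` stands in for the pretrained `model.conv1` so the example runs without downloading weights. Each entry of `src_channels` names the pretrained filter (0 = R, 1 = G, 2 = B) that initializes the corresponding new input channel, so the NIR, R, G, B order from the question becomes `[0, 0, 1, 2]`.

```python
import torch
import torch.nn as nn
from typing import Sequence


def expand_conv_input(conv: nn.Conv2d, src_channels: Sequence[int]) -> nn.Conv2d:
    """Hypothetical helper: build a Conv2d whose input channel i is
    initialized from input channel src_channels[i] of `conv`."""
    new_conv = nn.Conv2d(
        in_channels=len(src_channels),
        out_channels=conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=conv.bias is not None,
        padding_mode=conv.padding_mode,
    )
    idx = torch.tensor(list(src_channels), dtype=torch.long)
    with torch.no_grad():
        # Pick the pretrained filters along the input-channel dim (dim 1)
        new_conv.weight.copy_(conv.weight.index_select(1, idx))
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv


# Stand-in for the pretrained ResNet50 conv1 (random weights here)
rgb_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
# Input order NIR, R, G, B: NIR starts from the red filter
nir_rgb_conv = expand_conv_input(rgb_conv, [0, 0, 1, 2])
out = nir_rgb_conv(torch.randn(2, 4, 224, 224))
print(out.shape)  # torch.Size([2, 64, 112, 112])
```

With a real model you would call `model.conv1 = expand_conv_input(model.conv1, [0, 0, 1, 2])`; the rest of the network is untouched because conv1 still produces 64 output channels.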

Thank you for your suggestion, it works.

Hi,

Why is torch.no_grad() used here, and does it affect training if I want to fine-tune my model?

Thanks


@ptrblck why did you use with torch.no_grad()? Is it because you use it for inference here? If I want to re-train the network after initializing the weights like this, can I just assign the weights without using with torch.no_grad()?

I think it is to avoid this runtime error:
RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation
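A minimal sketch that reproduces the error on a standalone 4-channel conv layer (a stand-in for the new `model.conv1`): the weight is a leaf tensor that requires grad, `conv.weight[:, 3]` is a view of it, and assigning into that view outside the guard is rejected, while the same assignment inside `torch.no_grad()` goes through.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)

# Without the guard: in-place write into a view of a leaf that requires grad
try:
    conv.weight[:, 3] = conv.weight[:, 0]
    raised = False
except RuntimeError as err:
    raised = True
    print(err)

# With the guard: autograd does not record the copy, so it is allowed
with torch.no_grad():
    conv.weight[:, 3] = conv.weight[:, 0]
```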

Yes, you would need to wrap the assignment in the no_grad() guard so that Autograd doesn't track this operation.
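The guard only affects the operations inside the `with` block, so it does not get in the way of fine-tuning. A small sketch to check this, assuming a toy 4-channel conv layer and a dummy loss instead of the full ResNet50: after the guarded initialization the weight is still a leaf with `requires_grad=True`, it receives gradients, and an optimizer step updates it.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(4, 8, kernel_size=3, padding=1, bias=False)
with torch.no_grad():
    conv.weight[:, 3] = conv.weight[:, 0]  # initialization, not recorded by autograd

before = conv.weight.detach().clone()
opt = torch.optim.SGD(conv.parameters(), lr=0.1)

# One ordinary training step with a dummy loss
loss = conv(torch.randn(2, 4, 16, 16)).pow(2).mean()
opt.zero_grad()
loss.backward()
opt.step()

print(conv.weight.requires_grad, conv.weight.grad is not None)  # True True
```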