How do I duplicate the R-channel weights in DeepLabV3?

I've changed the number of input channels in DeepLabV3 because my images are NRGB (N = NIR, near-infrared). The next step is to duplicate the pretrained weights of the red channel for the new NIR channel, so that the whole model, including the first layer, starts from pretrained weights.

Any ideas?

import torch
import torch.nn as nn

class MyDeepLab(nn.Module):
    """Pretrained DeepLabV3-ResNet101 whose first conv is replaced by a 4-input-channel layer.

    NOTE(review): this is the question's starting point and is incomplete:
    - the replacement conv is a brand-new ``nn.Conv2d``, so none of the
      pretrained first-layer weights are reused;
    - its 3x3 kernel / padding=1 differ from the pretrained conv1 (7x7,
      padding=3), so the pretrained filters could not be copied into it as-is;
    - ``in_channels`` is accepted but ignored; 4 input channels are hard-coded.
    """

    def __init__(self, in_channels=1):
        super(MyDeepLab, self).__init__()
        # Load the pretrained segmentation model through torch.hub.
        self.model = torch.hub.load('pytorch/vision:v0.5.0', 'deeplabv3_resnet101', pretrained=True)
        # Swap in a 4-channel (NIR + RGB) first conv; its weights are new, not pretrained.
        self.model.backbone.conv1 = nn.Conv2d(4, 64, kernel_size=3, stride=2, padding=1, bias=False)

    def forward(self, x):
        # Delegate directly to the wrapped hub model.
        return self.model(x)


# Build the model and print its state_dict (parameters and buffers) for inspection.
my_deeplab = MyDeepLab()
print(my_deeplab.state_dict())

This code should work:


class MyDeepLab(nn.Module):
    """Pretrained DeepLabV3-ResNet101 whose first conv accepts ``in_channels`` inputs.

    The pretrained RGB filters of the backbone's first conv are kept for the
    last three input channels. Every extra leading channel (e.g. NIR in an
    NRGB image) is initialised with a copy of the red-channel filters, so the
    whole network starts from pretrained weights.

    Args:
        in_channels: Number of input channels; must be >= 3. The default of 4
            gives the NRGB layout: channel 0 = copy of R, channels 1..3 = R, G, B.
    """

    def __init__(self, in_channels=4):
        super(MyDeepLab, self).__init__()
        self.model = torch.hub.load('pytorch/vision:v0.5.0', 'deeplabv3_resnet101', pretrained=True)
        self.model.backbone.conv1 = self._expand_input_conv(self.model.backbone.conv1, in_channels)

    @staticmethod
    def _expand_input_conv(conv, in_channels):
        """Return a new conv like ``conv`` but taking ``in_channels`` inputs.

        Kernel size, stride, padding, dilation, bias and padding mode are copied
        from ``conv`` rather than hard-coded, so the pretrained filters always fit.
        The original filters fill the trailing ``conv.in_channels`` input slots.
        Each extra leading slot gets a copy of the filter for input channel 0
        (red, for an RGB-pretrained conv).

        Raises:
            ValueError: if ``in_channels`` is smaller than ``conv.in_channels``.
        """
        old_in = conv.in_channels
        if in_channels < old_in:
            raise ValueError(
                f"in_channels must be >= {old_in} to reuse the pretrained filters, got {in_channels}"
            )
        new_conv = nn.Conv2d(
            in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            bias=conv.bias is not None,
            padding_mode=conv.padding_mode,
        ).to(device=conv.weight.device, dtype=conv.weight.dtype)
        extra = in_channels - old_in
        # Fill weights in place without recording the copies in autograd.
        with torch.no_grad():
            new_conv.weight[:, extra:] = conv.weight
            if extra:
                # Broadcast the (out, 1, kH, kW) red filter across all extra channels.
                new_conv.weight[:, :extra] = conv.weight[:, 0:1]
            if conv.bias is not None:
                new_conv.bias.copy_(conv.bias)
        return new_conv

    def forward(self, x):
        return self.model(x)

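It initialises the new first input channel (NIR) with a copy of the pretrained red-channel filters and reuses the pretrained RGB filters for the remaining three channels.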
Note that your code used a different kernel size and padding (3x3, padding=1) than the original first conv layer (7x7, padding=3). The pretrained weights can only be copied when these match.