How to modify the input channels of a ResNet model

Here is a generic snippet for increasing the number of input channels to 4 or more.
A key point is that the weights for the additional channels are initialized by copying the weights of one of the original channels, rather than being randomly initialized.

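```python
import torch
import torch.nn as nn
from torchvision import models

new_in_channels = 4
# torchvision >= 0.13; older versions use models.resnet18(pretrained=True)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

layer = model.conv1

# Creating a new Conv2d layer with the same settings except in_channels
new_layer = nn.Conv2d(in_channels=new_in_channels,
                      out_channels=layer.out_channels,
                      kernel_size=layer.kernel_size,
                      stride=layer.stride,
                      padding=layer.padding,
                      bias=layer.bias is not None)  # bias expects a bool, not a tensor

copy_weights = 0  # Index of the original channel copied into the new channels (0 = red)

# In-place writes to a parameter that requires grad must happen under no_grad
with torch.no_grad():
    # Copying the weights from the old layer to the first channels of the new layer
    new_layer.weight[:, :layer.in_channels, :, :] = layer.weight.clone()

    # Copying the weights of the `copy_weights` channel of the old layer to the extra channels of the new layer
    for i in range(new_in_channels - layer.in_channels):
        channel = layer.in_channels + i
        new_layer.weight[:, channel:channel+1, :, :] = layer.weight[:, copy_weights:copy_weights+1, :, :].clone()

    # ResNet's conv1 has no bias, but copy it if the layer has one
    if layer.bias is not None:
        new_layer.bias.copy_(layer.bias)

model.conv1 = new_layer
```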
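The same steps can be wrapped into a reusable function. This is a minimal sketch, assuming a plain `groups=1` convolution such as ResNet's `conv1`; the name `expand_input_channels` and its `copy_channel` parameter are hypothetical. It also includes a quick sanity check: feeding an all-zero extra channel should reproduce the original layer's output, since that channel contributes nothing to the sum.

```python
import torch
import torch.nn as nn


def expand_input_channels(conv: nn.Conv2d, new_in_channels: int, copy_channel: int = 0) -> nn.Conv2d:
    """Return a new Conv2d like `conv` but accepting `new_in_channels` inputs (hypothetical helper).

    The original kernels are kept for the first conv.in_channels channels, and every
    extra channel reuses the kernel of original channel `copy_channel` (0 = red for RGB).
    """
    if new_in_channels < conv.in_channels:
        raise ValueError("new_in_channels must be >= conv.in_channels")
    if conv.groups != 1:
        raise ValueError("only groups=1 convolutions are supported")

    new_conv = nn.Conv2d(
        new_in_channels, conv.out_channels,
        kernel_size=conv.kernel_size, stride=conv.stride, padding=conv.padding,
        dilation=conv.dilation, bias=conv.bias is not None, padding_mode=conv.padding_mode,
        device=conv.weight.device, dtype=conv.weight.dtype,
    )
    with torch.no_grad():  # in-place writes to parameters must not be tracked by autograd
        new_conv.weight[:, :conv.in_channels] = conv.weight
        if new_in_channels > conv.in_channels:
            # Broadcasting copies the single (out, 1, kH, kW) slice into every extra channel.
            new_conv.weight[:, conv.in_channels:] = conv.weight[:, copy_channel:copy_channel + 1]
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv


# Usage on a layer shaped like resnet18's conv1 (randomly initialized here).
conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
new_conv = expand_input_channels(conv, 4)
x = torch.randn(2, 3, 32, 32)
x_zero_pad = torch.cat([x, torch.zeros(2, 1, 32, 32)], dim=1)
# A zero-valued extra channel leaves the output unchanged.
print(torch.allclose(conv(x), new_conv(x_zero_pad), atol=1e-5))  # expected: True
```

With a real model this would be applied as `model.conv1 = expand_input_channels(model.conv1, 4)`.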

As written, the code assumes `new_in_channels` is at least `layer.in_channels` (3 for ResNet), but it can be modified to work with one or two channels as well.
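For one or two channels the pretrained weights have to be reduced rather than extended. The sketch below shows one possible choice, not the only one: for a single (e.g. grayscale) channel it sums the three RGB kernels, which is exactly equivalent to repeating the grayscale image across three channels and using the original layer; for two channels it simply keeps the kernels of the first two original channels, which assumes those inputs behave roughly like red and green. The helper name `reduce_input_channels` is hypothetical.

```python
import torch
import torch.nn as nn


def reduce_input_channels(conv: nn.Conv2d, new_in_channels: int) -> nn.Conv2d:
    """Return a copy of `conv` accepting fewer input channels (hypothetical helper).

    1 channel: the original kernels are summed over the channel dimension.
    More than 1: the kernels of the first `new_in_channels` original channels are kept.
    """
    if not 1 <= new_in_channels < conv.in_channels:
        raise ValueError("new_in_channels must be between 1 and conv.in_channels - 1")
    if conv.groups != 1:
        raise ValueError("only groups=1 convolutions are supported")

    new_conv = nn.Conv2d(
        new_in_channels, conv.out_channels,
        kernel_size=conv.kernel_size, stride=conv.stride, padding=conv.padding,
        dilation=conv.dilation, bias=conv.bias is not None, padding_mode=conv.padding_mode,
        device=conv.weight.device, dtype=conv.weight.dtype,
    )
    with torch.no_grad():  # in-place writes to parameters must not be tracked by autograd
        if new_in_channels == 1:
            new_conv.weight.copy_(conv.weight.sum(dim=1, keepdim=True))
        else:
            new_conv.weight.copy_(conv.weight[:, :new_in_channels])
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv


# Summed kernels on a grayscale image match the original layer on that image repeated 3 times.
conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
gray = torch.randn(2, 1, 32, 32)
gray_conv = reduce_input_channels(conv, 1)
print(torch.allclose(conv(gray.repeat(1, 3, 1, 1)), gray_conv(gray), atol=1e-5))  # expected: True
```

Averaging instead of summing is another common option; it keeps the weight scale closer to a single channel but no longer reproduces the repeated-grayscale output exactly.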
