Initialising the weights of the first convolutional layer when changing the number of input channels of a pretrained ResNet

I am trying to change the number of input channels of a pretrained ResNet model in PyTorch. Here is the code:

import torch.nn as nn
from torchvision import models

in_channels = 1  # e.g. grayscale input
resnet50 = models.resnet50(weights="IMAGENET1K_V2")
resnet50.conv1 = nn.Conv2d(
    in_channels, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False
)

This seems to work. I am verifying it with:

summary(resnet50, (1, 224, 224))

This is returning the summary correctly.

However, I am confused because the weights of the new first convolutional layer are now initialised randomly. Is there a way to initialise this layer from the weights of the pretrained 3-channel PyTorch model? One option would be to average the pretrained weights over the input channels.

The weight tensor of the first conv layer has shape torch.Size([64, 3, 7, 7]) for 3 channels, whereas for 1 channel it is torch.Size([64, 1, 7, 7]). Does taking the mean across the second dimension make sense?
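Reducing over dimension 1 with `keepdim=True` gives exactly the `[64, 1, 7, 7]` shape the new layer needs. A minimal sketch, assuming a hypothetical helper `to_single_channel_conv` that builds a 1-channel `nn.Conv2d` with the same hyperparameters and fills it with the mean (or sum) of the pretrained RGB weights. A standalone 3-channel conv stands in for `resnet50.conv1` so the example runs without downloading weights.

```python
import torch
import torch.nn as nn


def to_single_channel_conv(conv: nn.Conv2d, reduce: str = "mean") -> nn.Conv2d:
    """Hypothetical helper: return a 1-input-channel copy of `conv`
    whose weights are the pretrained RGB weights reduced over dim 1."""
    new_conv = nn.Conv2d(
        1,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        bias=conv.bias is not None,
    )
    with torch.no_grad():
        if reduce == "mean":
            w = conv.weight.mean(dim=1, keepdim=True)  # [64, 3, 7, 7] -> [64, 1, 7, 7]
        elif reduce == "sum":
            w = conv.weight.sum(dim=1, keepdim=True)
        else:
            raise ValueError(f"unknown reduce: {reduce!r}")
        new_conv.weight.copy_(w)
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv


# Stand-in for resnet50.conv1 (same hyperparameters), so this runs offline
rgb_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
gray_conv = to_single_channel_conv(rgb_conv, reduce="mean")
print(gray_conv.weight.shape)  # torch.Size([64, 1, 7, 7])
```

With the real model, the equivalent call would be `resnet50.conv1 = to_single_channel_conv(resnet50.conv1)` right after loading `weights="IMAGENET1K_V2"`. The helper creates the new layer on the CPU, so move the model to the GPU afterwards.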

Hi @Adit_Whorra,
I think there are two solutions you can explore:

  1. Adapt the input. Convert the 1-channel image to a 3-channel one by replicating the channel.
  2. Adapt the first model layer. This is the solution you suggested: initialise the new first layer with the mean or the sum over the input channels of the 3-channel pretrained weights.
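For the first option, the unmodified pretrained model can be used as-is and only the input tensor changes. A minimal sketch using `Tensor.expand`, which creates a view without copying the data (`repeat(1, 3, 1, 1)` would copy):

```python
import torch

# A batch of grayscale images: (batch, 1, height, width)
gray = torch.rand(4, 1, 224, 224)

# Same single channel seen three times; no extra memory is allocated
rgb_like = gray.expand(-1, 3, -1, -1)
print(rgb_like.shape)  # torch.Size([4, 3, 224, 224])
```

In a torchvision preprocessing pipeline, `transforms.Grayscale(num_output_channels=3)` achieves the same effect on the image side.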

Evaluate both solutions on the ImageNet dataset (converted to grayscale), then decide.

Note: neither solution requires re-training.
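One way to see why re-training is not strictly needed: a convolution is linear in its input channels, so a 3-channel kernel applied to an image whose three channels are identical gives the same result as the channel-summed kernel applied to the single channel, i.e. $\sum_c W_c * g = \left(\sum_c W_c\right) * g$. Summing therefore reproduces the original first-layer output for replicated input, while averaging scales it by 1/3 (the following BatchNorm layer, with its fixed running statistics in eval mode, does not undo that scaling). This is exact only when the three channels are truly identical; if ImageNet's per-channel normalisation is applied after replication, they differ and the match becomes approximate. A minimal check with a random stand-in for `conv1`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Random stand-in for the pretrained 3-channel conv1 of ResNet-50
rgb_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

# 1-channel convs initialised by summing / averaging over the input channels
sum_conv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
mean_conv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
with torch.no_grad():
    sum_conv.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))
    mean_conv.weight.copy_(rgb_conv.weight.mean(dim=1, keepdim=True))

gray = torch.rand(2, 1, 64, 64)          # single-channel input
replicated = gray.expand(-1, 3, -1, -1)  # same image in all three channels

with torch.no_grad():
    y_rgb = rgb_conv(replicated)
    y_sum = sum_conv(gray)
    y_mean = mean_conv(gray)

print(torch.allclose(y_rgb, y_sum, atol=1e-5))       # True
print(torch.allclose(y_rgb, 3 * y_mean, atol=1e-5))  # True
```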

The first solution will unnecessarily add extra parameters so I think the second option makes more sense. Will the resultant weights still retain what the model has learnt from pre-training though after taking the mean or sum?

The first solution will unnecessarily add extra parameters so I think the second option makes more sense.

I would not say so. The number of extra parameters is negligible.

Consider the following example:

from torchvision import models
import torch.nn as nn

def count_parameters(model):
    # Total number of trainable parameters
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

resnet50_gray = models.resnet50(weights="IMAGENET1K_V2")
resnet50_gray.conv1 = nn.Conv2d(
    1, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False
)

resnet50_rgb = models.resnet50(weights="IMAGENET1K_V2")
>>> print("gray:", count_parameters(resnet50_gray))
gray: 25550760
>>> print("rgb:", count_parameters(resnet50_rgb))
rgb: 25557032
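The gap is exactly the size difference of the first conv weight, since every layer after `conv1` is unchanged and both versions use `bias=False`. Worked out from the kernel shapes, with the total taken from the printout above:

```python
# Weight count of a conv layer: out_channels * in_channels * kH * kW
def conv_weight_count(out_ch, in_ch, k):
    return out_ch * in_ch * k * k

rgb_conv1 = conv_weight_count(64, 3, 7)   # 9408
gray_conv1 = conv_weight_count(64, 1, 7)  # 3136
diff = rgb_conv1 - gray_conv1             # 6272

total_rgb = 25557032  # ResNet-50 total from the printout above
print(diff, total_rgb - diff)             # 6272 25550760
print(f"{100 * diff / total_rgb:.3f}%")   # 0.025%
```

So keeping the 3-channel layer costs about 0.025% more parameters.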

Will the resultant weights still retain what the model has learnt from pre-training though after taking the mean or sum?

The main point, IMO, is to measure how many accuracy points are lost after that operation. That's why you should try all the alternatives.
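A minimal sketch of that comparison, assuming a labelled validation `DataLoader` for each input variant; `top1_accuracy` is a hypothetical helper, not a torchvision function. The idea is to run it on the original RGB model, the replicated-input setup, and the mean- and sum-initialised 1-channel models, then compare the numbers.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def top1_accuracy(model: nn.Module, loader: DataLoader) -> float:
    """Fraction of samples whose highest-scoring class matches the label."""
    model.eval()
    correct, total = 0, 0
    for images, labels in loader:
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.numel()
    return correct / total


# Toy usage with random data; in practice the loader would yield grayscale
# ImageNet validation images (1 channel for the adapted conv1, or replicated
# to 3 channels for the unmodified model).
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(1 * 8 * 8, 10))
data = TensorDataset(torch.rand(32, 1, 8, 8), torch.randint(0, 10, (32,)))
acc = top1_accuracy(toy_model, DataLoader(data, batch_size=8))
print(f"top-1 accuracy: {acc:.3f}")
```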