Fine-tuning a pretrained ResNet for grayscale image classification

I am working with a model that combines ResNet and an LSTM for image classification. I used ResNet and only changed the first layer, as suggested in this post, but now I want to fine-tune some layers. The post says fine-tuning isn’t possible, but I saw a suggestion that you can average the pretrained weights over the three input channels and then use them. I didn’t understand how that works.
This is what I have done so far:

import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights

class CNNLSTM(nn.Module):
    def __init__(self, num_classes=5):
        super(CNNLSTM, self).__init__()

        # Load pretrained ResNet-50
        self.resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
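
        # Replace conv1 to accept 1-channel input; this new layer is randomly
        # initialized, so the pretrained RGB filters of conv1 are discarded
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # I only need ResNet for feature extraction (drop the final fc layer)
        self.feature_extractor = nn.Sequential(*list(self.resnet.children())[:-1])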

        self.lstm = nn.LSTM(
            input_size=2048,
            hidden_size=256,
            num_layers=2,
            batch_first=True           
        )

        self.fc = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        
        batch_size, timesteps, C, H, W = x.shape

        x = x.view(batch_size * timesteps, C, H, W)

        
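        # Note: torch.no_grad() blocks gradients through the whole ResNet, so the
        # new conv1 stays at its random initialization and no ResNet layer is trained
        with torch.no_grad():
            features = self.feature_extractor(x)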
        features = features.view(batch_size, timesteps, -1) 

        lstm_out, _ = self.lstm(features)  
        last_out = lstm_out[:, -1, :]      

        # Classification
        output = self.fc(last_out)    
        return output

Does anyone have any suggestions?

I don’t fully understand your use case. If you want to reuse the pretrained first conv layer and average its weights over the three input channels, you could do:

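import copy
import torch
import torch.nn as nn
from torchvision import models

# Load ImageNet weights; models.resnet18() alone would average random filters
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)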

conv = copy.deepcopy(model.conv1)
print(conv.weight.shape)
# torch.Size([64, 3, 7, 7])
print(conv.bias)
# None

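# Average over the input-channel dim (dim 1) and keep it, giving 1 input channel
conv.weight = nn.Parameter(conv.weight.mean(dim=1, keepdim=True))
conv.in_channels = 1  # keep the module's metadata consistent with the new weight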
print(conv.weight.shape)
# torch.Size([64, 1, 7, 7])

model.conv1 = conv

x = torch.randn(1, 1, 224, 224)
out = model(x)
print(out.shape)
# torch.Size([1, 1000])
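To see why averaging the channels is a sensible initialization: conv1 sums its three per-channel convolutions, so if a grayscale image is copied into all three channels, each output equals a single-channel convolution whose filter is the sum of the three RGB filters. Averaging gives the same filters divided by 3, i.e. the same responses scaled by 1/3. The check below is a small sketch that uses a randomly initialized stand-in for conv1 (no pretrained download), since the identity holds for any weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in for ResNet's conv1; random weights are enough to check the identity.
rgb_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

# Single-channel conv whose filters are the channel-wise SUM of the RGB filters.
gray_conv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
with torch.no_grad():
    gray_conv.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))

gray = torch.randn(2, 1, 224, 224)  # batch of grayscale images
rgb = gray.repeat(1, 3, 1, 1)       # same image copied into R, G and B

with torch.no_grad():
    out_rgb = rgb_conv(rgb)
    out_sum = gray_conv(gray)
    # Channel-wise MEAN of the RGB filters: every response is scaled by 1/3.
    out_mean = F.conv2d(
        gray, rgb_conv.weight.mean(dim=1, keepdim=True), stride=2, padding=3
    )

print(torch.allclose(out_rgb, out_sum, atol=1e-4))      # True
print(torch.allclose(out_rgb / 3, out_mean, atol=1e-4))  # True
```

If you want the grayscale layer to reproduce the original responses exactly rather than up to a factor of 1/3, you could use `sum(dim=1, keepdim=True)` instead of `mean`.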

However, I don’t know whether averaging the pretrained RGB filters into a single channel would outperform training the grayscale layer from scratch, so you should compare both approaches.