Can I reuse pre-trained weights when adding input channels?

Hello!

I’m trying to use a ResNet18 model to process Sentinel-2 satellite imagery for a regression problem. I have weights that were pre-trained on 13 channels, but I’d like to add channels carrying additional information (temperature, sea depth, etc.). As I have a very limited amount of data, I can’t afford to re-train the whole model, so I’d like to reuse those weights.

Looking around on forums, I found that the common solution is to modify the first layer to accept more channels, then clone the pre-trained weights and initialize the new weights (either randomly or from the existing ones). However, I’m not sure about the integrity of the weights in the subsequent layers. If I understand correctly, the weights of every layer after the first will initially stay the same. If I then feed the model new information through my extra channels, won’t this mess up the pre-trained weights? Should I freeze some of them, or change the learning rate?
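For reference, this is roughly what I mean by expanding the first layer. It is only a sketch: `n_new_channels` is a placeholder, and I’ve left out how the 13-channel pre-trained state_dict is actually loaded.

```python
import torch
import torch.nn as nn
import torchvision

n_old_channels = 13
n_new_channels = 16  # e.g. 13 Sentinel-2 bands + temperature, sea depth, ...

# Build a ResNet18 whose first conv takes 13 channels, then load the
# 13-channel pre-trained weights into it (loading step omitted here).
model = torchvision.models.resnet18()
model.conv1 = nn.Conv2d(n_old_channels, 64, kernel_size=7, stride=2,
                        padding=3, bias=False)
# ... load the 13-channel pre-trained state_dict into `model` here ...

# Replace the first conv with a wider one and copy the pre-trained filters.
old_conv = model.conv1
new_conv = nn.Conv2d(n_new_channels, old_conv.out_channels,
                     kernel_size=old_conv.kernel_size,
                     stride=old_conv.stride,
                     padding=old_conv.padding,
                     bias=False)

with torch.no_grad():
    # keep the pre-trained filters for the original 13 bands
    new_conv.weight[:, :n_old_channels] = old_conv.weight
    # initialize the extra channels, e.g. with the mean of the existing filters
    new_conv.weight[:, n_old_channels:] = old_conv.weight.mean(dim=1, keepdim=True)

model.conv1 = new_conv
```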

Thanks in advance for any input!

I think your concern is valid, and I would also expect that at least fine-tuning the first layer is needed. I don’t know whether retraining the already pre-trained layers is necessary or helps.
As an alternative you could also consider using a new “mapping” layer, which would transform your input channels to the number of channels expected by your pre-trained model. If you use a kernel size of 1x1 for this layer, you could see it as a “trainable color transformation”.
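Something like this minimal sketch, assuming your 13-channel pre-trained model is already loaded as `pretrained_resnet` (the names and channel counts are just placeholders):

```python
import torch.nn as nn

class MappedResNet(nn.Module):
    def __init__(self, pretrained_resnet, n_in_channels=16, n_pretrained_channels=13):
        super().__init__()
        # 1x1 conv = per-pixel linear combination of input channels,
        # i.e. a "trainable color transformation"
        self.mapping = nn.Conv2d(n_in_channels, n_pretrained_channels, kernel_size=1)
        self.backbone = pretrained_resnet

    def forward(self, x):
        return self.backbone(self.mapping(x))
```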

Thank you for your reply. I’m concerned that this mapping layer would introduce error in the final prediction, as some of my new channels are not at all correlated with the imagery. Would it be possible, with torch, to run two CNNs in parallel with a common loss function and merge their outputs at the last layer? That way I would keep one CNN on the 13-band Sentinel imagery with the frozen pre-trained weights, and train the other from scratch in parallel on the new channels only.

Yes, you can create such an architecture, but you would then of course need to decide how these features should be fused, e.g. you could concatenate them or apply any reduction operation.
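A rough sketch of one way to set this up, assuming concatenation of the two feature vectors before a small regression head (the second branch, the feature sizes, and the class name are all illustrative assumptions, not a fixed recipe):

```python
import torch
import torch.nn as nn

class TwoBranchRegressor(nn.Module):
    def __init__(self, pretrained_resnet, n_extra_channels=3):
        super().__init__()
        # branch 1: pre-trained ResNet18 on the 13 Sentinel-2 bands, frozen
        self.branch_s2 = pretrained_resnet
        self.branch_s2.fc = nn.Identity()  # keep the 512-d pooled features
        for p in self.branch_s2.parameters():
            p.requires_grad = False

        # branch 2: small CNN trained from scratch on the extra channels
        self.branch_extra = nn.Sequential(
            nn.Conv2d(n_extra_channels, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
        )

        # regression head on the concatenated features (512 + 64)
        self.head = nn.Linear(512 + 64, 1)

    def forward(self, x_s2, x_extra):
        f1 = self.branch_s2(x_s2)        # [B, 512]
        f2 = self.branch_extra(x_extra)  # [B, 64]
        return self.head(torch.cat([f1, f2], dim=1))
```

With this layout only the second branch and the head receive gradients, so a single regression loss on the output trains the new channels while the pre-trained weights stay frozen.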

Thank you, I will try a few things on parallel networks and compare the fusion methods.