Using a pre-trained network for input with more than 3 channels


I have a satellite dataset whose images contain more than 3 channels. Since the data is limited, I am using pre-trained models: rather than training from scratch, I fine-tune a pre-trained model with the initial layers frozen. In that case, is there a way to use or fuse the remaining channels of my images with the pre-trained model?
Any inputs would be appreciated.


It would depend on the use case and what kind of information is stored in each channel.
You could certainly try to reduce the channels to three, but you would need to check whether you are losing valid information. Another approach would be to either replace the first layer with a new conv that accepts more input channels (keeping the existing kernels and adding new randomly initialized ones for the extra channels), or to add a completely new “0th” layer that transforms the images to 3 channels (let’s call this a “trainable reduction”) using e.g. a 1x1 kernel.
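As a rough sketch of the first option (replacing the first layer), something along these lines might work. The helper name `expand_first_conv` is made up for illustration, and it assumes the original conv uses `groups=1`. A plain `nn.Conv2d` stands in for the pre-trained first layer so the snippet runs without downloading weights.

```python
import torch
import torch.nn as nn


def expand_first_conv(conv: nn.Conv2d, in_channels: int) -> nn.Conv2d:
    """Hypothetical helper: copy `conv` into a new conv with more input channels.

    Assumes conv.groups == 1 and in_channels >= conv.in_channels.
    """
    new_conv = nn.Conv2d(
        in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        bias=conv.bias is not None,
    )
    with torch.no_grad():
        # Reuse the existing kernels for the original channels;
        # the extra channels keep their random initialization.
        new_conv.weight[:, :conv.in_channels] = conv.weight
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv


# Stand-in for a pre-trained first layer (same shape as ResNet-18's conv1).
old_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
new_conv = expand_first_conv(old_conv, in_channels=4)

x = torch.randn(1, 4, 224, 224)  # dummy 4-channel batch
out = new_conv(x)
```

With a torchvision ResNet, the result would be assigned back with something like `model.conv1 = expand_first_conv(model.conv1, 4)`. The new layer has to stay trainable, since the kernels for the extra channels start out random.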

Thanks for the input. I can try adding a new 0th layer or manipulating the first layer as you suggested. I need the other channels, as they contain information from the infrared band and help in understanding the landscape.
Since I am a PyTorch novice, could you point me to an example where this is done? It would help me a great deal!

Thanks in advance.


Something like this should work:

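import torch
import torch.nn as nn
from torchvision import models

model = models.resnet18()  # load the pre-trained weights here in practice
conv0 = nn.Conv2d(4, 3, 1)  # trainable 1x1 conv: 4 channels -> 3 channels

x = torch.randn(1, 4, 224, 224)  # dummy 4-channel batch
out = conv0(x)  # shape [1, 3, 224, 224]
out = model(out)  # shape [1, 1000]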

You could of course wrap it in a new custom module, which would be a bit cleaner.
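A minimal sketch of such a wrapper could look like the following. The class name `ChannelReducer` and the tiny stand-in backbone are made up for illustration; in practice the backbone would be the pre-trained model. Freezing the whole backbone here is also just an assumption to show that the new layer remains trainable.

```python
import torch
import torch.nn as nn


class ChannelReducer(nn.Module):
    """Hypothetical wrapper: a trainable 1x1 conv followed by a 3-channel backbone."""

    def __init__(self, backbone: nn.Module, in_channels: int = 4):
        super().__init__()
        # The "0th" layer: learns how to mix the input channels down to 3.
        self.conv0 = nn.Conv2d(in_channels, 3, kernel_size=1)
        self.backbone = backbone

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.backbone(self.conv0(x))


# Tiny stand-in for a pre-trained 3-channel model (illustration only).
backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)

# Freeze the backbone; only conv0 will receive gradient updates.
for p in backbone.parameters():
    p.requires_grad_(False)

model = ChannelReducer(backbone, in_channels=4)

x = torch.randn(2, 4, 64, 64)  # dummy 4-channel batch
out = model(x)
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
```

Only the parameters listed in `trainable` would need to be passed to the optimizer.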
