Multi-channel image classification model

I am using resnet50 to classify multi-channel (10-channel) images by changing the number of input channels from 3 to 10, but I am getting low accuracy (35%) on the test data. I am starting to wonder whether ResNet is designed for 3 channels and might not work for 10. Can someone suggest which models are used for training on images with this many channels?

What do your channels represent?

Bands (wavelengths) of satellite images.

Got it, and are you already using the pretrained resnet50 weights everywhere except your first convolution and the classifier?

Yes. I load resnet50 with pretrained weights, then set the first layer (conv1) to 10 input channels and the last layer to my number of classes, which discards the pretrained weights for only these two layers.

OK, that sounds reasonable. Separate from the extra channels, are the types of objects you're looking to classify similarly shaped to what the ResNet was trained on, or totally different? If they're totally different, your later pretrained ResNet features may be inappropriate for your problem. The early layers might be OK (modulo the channel difference), since they detect primitives such as edges and circles, but the more complex features downstream may not be appropriate.

Something you can try is the following (inspired by fastai / Kaggle best practices):

  • Train for some epochs while freezing the remaining pretrained weights (2nd, 3rd, …, (n-1)th layers frozen; 1st and nth layers unfrozen).
  • Save the state dict at this point.
  • From there, experiment with differential learning rates when you unfreeze the 2nd through (n-1)th layers: give higher learning rates to the new layers and to the later layers (n-4, n-3, etc.) than to the early ones.
  • Experiment with resetting the later ResNet features, per my comment above; they may not be what you want for your problem.

Separately from all of this, make sure you're using the typical training tricks (data augmentation and so on), and check that they are adapted to your problem. For example, for satellite imagery I imagine shearing is not particularly useful (it simulates different 3D perspectives, but satellite images are flat); instead, you want to rotate your images by up to 360 degrees and flip them.
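A minimal sketch of such an augmentation for a (bands, H, W) tensor with any number of bands: a random rotation by a multiple of 90 degrees plus a random flip, which together cover all 8 symmetries of a square tile without any interpolation. Restricting to 90 degree steps is my simplification of the "up to 360 degrees" suggestion; arbitrary angles need interpolation and a fill value for the corners. It also assumes square tiles, since a 90 degree rotation swaps H and W.

```python
import torch


def random_dihedral(x: torch.Tensor) -> torch.Tensor:
    """Randomly rotate a (C, H, W) tensor by 0/90/180/270 degrees and maybe flip it.

    Works for any number of bands C because it only moves pixel positions,
    and the same spatial transform is applied to every band.
    """
    k = int(torch.randint(0, 4, (1,)))       # number of 90 degree turns
    x = torch.rot90(x, k, dims=(-2, -1))     # rotate in the spatial plane
    if torch.rand(1).item() < 0.5:
        x = torch.flip(x, dims=(-1,))        # horizontal flip
    return x
```

You would call it in your Dataset's `__getitem__` on training samples only, leaving validation and test samples untouched.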

If it turns out that the network capacity is insufficient (though I'd be surprised), you can even experiment with adding more channels to the early ResNet layers. You then have to carefully copy the pretrained weights into only the part of each layer whose size matches (you can't just assign them to the entire layer anymore, since there will be a size mismatch).
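A minimal sketch of that partial copy for convolutions, assuming both layers have the same kernel size and `groups=1`. The helper name `copy_overlapping_weights` is hypothetical. Filters and input channels beyond the old layer's size keep their fresh initialization.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def copy_overlapping_weights(new: nn.Conv2d, old: nn.Conv2d) -> None:
    """Copy `old`'s pretrained weights into the overlapping slice of a resized `new` conv."""
    assert new.kernel_size == old.kernel_size, "kernel sizes must match"
    assert new.groups == old.groups == 1, "sketch assumes ungrouped convolutions"
    out_c = min(new.out_channels, old.out_channels)
    in_c = min(new.in_channels, old.in_channels)
    # Weight layout is (out_channels, in_channels, kH, kW); copy only the shared block.
    new.weight[:out_c, :in_c] = old.weight[:out_c, :in_c]
    if new.bias is not None and old.bias is not None:
        new.bias[:out_c] = old.bias[:out_c]


# Example: widen a 64-filter conv to 96 filters, keeping the 64 pretrained ones.
old = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False)
wider = nn.Conv2d(64, 96, kernel_size=3, padding=1, bias=False)
copy_overlapping_weights(wider, old)
```

Widening one conv also means widening the BatchNorm after it and the input channels of the next conv; the same slicing idea applies there (for BatchNorm, to `weight`, `bias`, `running_mean` and `running_var`). The same helper can also seed the first 3 input channels of a 10-band `conv1` from the pretrained RGB filters, if you want to try that instead of a random init.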

Hope this helps!

Actually, the training error is low and the test error is high. I also tried resnet18, hoping it would overfit less, but got the same results. I am training a multi-label model and the data is very unbalanced. I noticed that per-class accuracy is high for classes with fewer samples and low for classes with more samples, so I am first going to balance the multi-label dataset, and only then consider blaming the architecture, which is less likely to be the culprit.
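Besides resampling the dataset, a common complement for imbalanced multi-label data is reweighting the loss. A minimal sketch using the real `pos_weight` argument of PyTorch's `nn.BCEWithLogitsLoss`, set per class to (negatives / positives) in the training labels; that particular weighting rule and the helper name `make_weighted_bce` are my assumptions, not something from this thread.

```python
import torch
import torch.nn as nn


def make_weighted_bce(labels: torch.Tensor) -> nn.BCEWithLogitsLoss:
    """Build a multi-label loss that up-weights the positives of rare classes.

    labels: (N, num_classes) tensor of 0/1 targets for the training set.
    """
    labels = labels.float()
    pos = labels.sum(dim=0)                 # positives per class
    neg = labels.shape[0] - pos             # negatives per class
    pos_weight = neg / pos.clamp(min=1.0)   # clamp avoids division by zero for empty classes
    return nn.BCEWithLogitsLoss(pos_weight=pos_weight)


labels = torch.tensor([[1, 0], [1, 0], [1, 1], [0, 0]])
criterion = make_weighted_bce(labels)
# Class 0 has 3 positives / 1 negative, class 1 has 1 / 3, so pos_weight is [1/3, 3].
```

For resampling instead, `torch.utils.data.WeightedRandomSampler` exists, though choosing per-sample weights is less straightforward when each sample carries several labels.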

Actually, it is overfitting, but not because the network capacity is too high.

Sounds like a useful observation and a good plan. More aggressive data augmentation might also help with mitigating overfitting, to the extent that’s possible.

If you're using a standard train/valid/test split, you should be able to track per-class validation accuracy and see the inflection point where overfitting begins.
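A minimal sketch of per-class validation metrics for the multi-label case, assuming the model outputs one logit per class and a 0.5 decision threshold (both assumptions). With imbalanced data, per-class accuracy can be dominated by the many negatives, so it may be worth logging per-class recall next to it.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def per_class_accuracy(logits: torch.Tensor, targets: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    # Fraction of samples whose 0/1 prediction matches the target, per class.
    preds = (torch.sigmoid(logits) > threshold).float()
    return (preds == targets.float()).float().mean(dim=0)


@torch.no_grad()
def per_class_recall(logits: torch.Tensor, targets: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    # Fraction of true positives that were predicted positive, per class.
    preds = torch.sigmoid(logits) > threshold
    t = targets.bool()
    tp = (preds & t).sum(dim=0).float()
    return tp / t.sum(dim=0).clamp(min=1).float()


@torch.no_grad()
def evaluate(model: nn.Module, loader, device: str = "cpu"):
    # Collect logits over the whole validation set, then compute metrics once.
    # Assumes the model is already on `device` and the loader yields (x, y) batches.
    model.eval()
    all_logits, all_targets = [], []
    for x, y in loader:
        all_logits.append(model(x.to(device)).cpu())
        all_targets.append(y)
    logits, targets = torch.cat(all_logits), torch.cat(all_targets)
    return per_class_accuracy(logits, targets), per_class_recall(logits, targets)
```

Calling `evaluate` after every epoch and logging both vectors shows, class by class, the epoch where the validation metric stops improving while the training loss keeps falling, which is roughly where overfitting sets in.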