Really poor result using transfer learning on medical images

I am using an Inception V3 model pre-trained on ImageNet. When I train it on the Bees and Ants dataset from the PyTorch transfer learning tutorial, I get this result:

Dataset statistics:

  * Train:
    * Bees: 121
    * Ants: 124
  * Val:
    * Bees: 83
    * Ants: 70

Best validation accuracy: 0.95 at epoch 6 (trained for 100 epochs)

However, when I try it on my own medical images, also a binary classification problem, with everything kept the same apart from the std/mean vectors, I get really poor results:

* Number of 512x512 images with 3 channels:
  * Train label +: 34,913
  * Train label -: 27,785
  * Val label +: 9,305
  * Val label -: 5,940

Best val Acc: 0.609708 and best epoch 31

Could you please suggest some ways to look into this problem and debug it? What are some methods I could use to improve the accuracy on the validation set? Also, does PyTorch have any built-in tool to measure the domain gap or dissimilarity between ImageNet and each of these two datasets, namely my in-house medical data and the Bees and Ants data?

I have used the following formula for calculating the std and mean vectors for the train/test/val sets; the resulting values for my medical data are shown in the comments below:

from torchvision import transforms

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # ImageNet statistics:
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        # statistics computed on my medical train set:
        #transforms.Normalize([0.7031, 0.5487, 0.6750], [0.2115, 0.2581, 0.1952])
    ]),
    'val': transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        # ImageNet statistics:
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        # statistics computed on my medical val set:
        #transforms.Normalize([0.7016, 0.5549, 0.6784], [0.2099, 0.2583, 0.1998])
    ]),
    'test': transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        # ImageNet statistics:
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        # statistics computed on my medical test set:
        #transforms.Normalize([0.7048, 0.5509, 0.6763], [0.2111, 0.2576, 0.1979])
    ])
}

import torch

# Compute the per-channel mean and std of a dataset, for use in
# transforms.Normalize. Note that averaging per-batch statistics
# like this is exact only when every batch has the same size.
def get_mean_std(loader):
    # Var[X] = E[X**2] - E[X]**2
    channels_sum, channels_squared_sum, num_batches = 0, 0, 0
    for data, _ in loader:
        channels_sum += torch.mean(data, dim=[0, 2, 3])
        channels_squared_sum += torch.mean(data ** 2, dim=[0, 2, 3])
        num_batches += 1

    mean = channels_sum / num_batches
    std = (channels_squared_sum / num_batches - mean ** 2) ** 0.5
    return mean, std
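
For reference, a minimal usage sketch of this function (the dataset path and loader settings are assumptions for illustration; drop_last keeps every batch the same size, which the averaging above assumes):

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Hypothetical dataset path, for illustration only. Compute the stats
# from un-normalized tensors, so use ToTensor without Normalize.
stats_dataset = datasets.ImageFolder(
    'data/medical/train',
    transform=transforms.Compose([
        transforms.Resize((512, 512)),  # images are 512x512 per the post
        transforms.ToTensor(),
    ])
)
stats_loader = DataLoader(stats_dataset, batch_size=64, shuffle=False,
                          drop_last=True)

mean, std = get_mean_std(stats_loader)
print(mean, std)  # values to plug into transforms.Normalize for this split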

First, in your assessment, is the difficulty of the two problems similar? In other words, to you as a human attempting the task, does it feel comparably easy to classify the two datasets? If the medical classification feels far trickier to a person, that could be the answer (it's a harder problem, and perhaps you need a lot more data to do it right). The rest of this answer assumes they're similarly difficult.

I recommend you visualize the train, validation, and test data after the transformations you're performing (a quick sketch follows below). It's possible some of those transformations are obscuring the elements required to correctly classify the images; specifically, I'm thinking of RandomResizedCrop and CenterCrop. If that happens, the model effectively loses access to some essential data (sometimes!), which will interfere with training. Even for plain resizing, make sure you as a human can still classify the data with high accuracy once it's been downsampled to your model's input size.
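
A minimal sketch for eyeballing a transformed batch (train_loader is a placeholder for your DataLoader, and the constants to undo are the ImageNet statistics from your transforms):

import matplotlib.pyplot as plt
import torch
import torchvision

# Grab one transformed batch (assumes a DataLoader named train_loader).
images, labels = next(iter(train_loader))

# Undo the ImageNet normalization so the images are viewable again.
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
images = images * std + mean

# Tile the first 16 images into a grid and check, by eye, that the
# diagnostic regions survived the crops and the resize.
grid = torchvision.utils.make_grid(images[:16], nrow=4).clamp(0, 1)
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.axis('off')
plt.show()
print(labels[:16].tolist())  # corresponding labels, row by row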

If that's not the issue, perhaps the relevant features you want the model to recognize in your data live at a different scale than those in the Bees and Ants dataset. You may try experimenting with different kernel sizes / strides / pooling sizes depending on your own dataset (see the sketch below).
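
If you go down that road, here is a hedged sketch of one way to do it with torchvision's Inception V3: swap its first 3x3 stem convolution for a 5x5 one (Conv2d_1a_3x3 is torchvision's attribute name; the 5x5 kernel is an arbitrary illustration, and the new layer is randomly initialized, so it needs training):

import torch.nn as nn
from torchvision import models

model = models.inception_v3(pretrained=True)

# Replace the 3x3 stem convolution with a 5x5 one to enlarge the
# receptive field; padding=1 keeps the output spatial size identical
# (149x149 for a 299x299 input), so the rest of the network is unchanged.
model.Conv2d_1a_3x3.conv = nn.Conv2d(3, 32, kernel_size=5, stride=2,
                                     padding=1, bias=False)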

Hope this helps!


Hi Andrei,

Thanks a lot for taking the time to guide me. I will report back after I try some of your very helpful suggestions. It seems I am also having some problems with the normalization mean and std values; I'll create a post and share it with you shortly.

Hi Andrei, unfortunately, removing the random crop, the center crop, and the random flip, then retraining, didn't help with the test set specificity and sensitivity.

This is not a task that I, as a computer scientist, would be able to figure out; it would also be a hard task for a medical doctor.

Unfortunately, I don’t have more images.

Hi Mona,

Sorry to hear that didn’t help. Generically, some things you can try:

  • Plotting your validation loss / accuracy by epoch and seeing whether it flatlines or something weirder is happening (e.g. it generally declines but has occasional spikes, which could indicate some of your data is mislabeled). TensorBoard and wandb.ai are useful tools here, as they take care of the plotting for you.
  • Alongside it, plot your training loss as well. If your training loss keeps decreasing while your validation loss flatlines, your model is overfitting (learning spurious characteristics of the training set that are not actually relevant for your classification). Depending on how the training / validation data were collected (e.g. in the training set, lots of the + data came from one particular type of machine that has some fingerprint, but that's not true of the validation / test data), this is possible and would be a good insight.
  • Explicitly looking at your precision and recall by label, so 4 buckets total (see the sketch after this list). Perhaps you will notice that one bucket is responsible for the bulk of the error, which will help you investigate further.
  • Running your best model on your own training data and staring at examples of images it gets right / wrong. See if you notice some pattern.
  • Ditto for your validation data.
  • Trying other loss functions - you didn't mention what you're using now, but I'm assuming it's something like BCELoss. Perhaps you can try something more robust to outliers, like SmoothL1Loss (you may have to change your classifier head and one-hot encode the target).
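
For the precision / recall bullet, a minimal sketch using scikit-learn (all_preds and all_targets are placeholders for the predicted and true class indices you'd collect during one validation pass):

import numpy as np
from sklearn.metrics import classification_report, confusion_matrix

# Hypothetical arrays of true and predicted class indices (0 or 1),
# gathered over one full validation pass.
targets = np.asarray(all_targets)
preds = np.asarray(all_preds)

# The confusion matrix is exactly the 4 buckets (TN, FP / FN, TP);
# the report adds per-class precision and recall on top.
print(confusion_matrix(targets, preds))
print(classification_report(targets, preds,
                            target_names=['negative', 'positive']))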

Wish you luck!

Best val Acc: 0.604788 and best epoch 50

This is what I am getting from Weights & Biases:

train Loss: 0.7393 Acc: 0.7157
val Loss: 0.6735 Acc: 0.5955

Epoch 48/99
----------
train Loss: 0.7371 Acc: 0.7172
val Loss: 0.6821 Acc: 0.5907

Epoch 49/99
----------
train Loss: 0.7385 Acc: 0.7141
val Loss: 0.6804 Acc: 0.6024

Epoch 50/99
----------
train Loss: 0.7432 Acc: 0.7141
val Loss: 0.6727 Acc: 0.6048

Epoch 51/99
----------
train Loss: 0.7414 Acc: 0.7144
val Loss: 0.6782 Acc: 0.5958

Epoch 52/99
----------

Epoch 99/99
----------
train Loss: 0.7386 Acc: 0.7152
val Loss: 0.6757 Acc: 0.5970

Training complete in 221m 57s

Best val Acc: 0.604788 and best epoch 50

I am using cross-entropy loss. I am using weak labels for my images: if a gigapixel image has label 1, I set all the 512x512 patches inside it to label 1 (even though only 10 of those patches might actually be positive). Likewise, if the gigapixel image has label 0, I set the label of all of its 512x512 patches to 0.

I tried to follow this paper, but I am not getting accuracy similar to theirs:
Classification and mutation prediction from non-small cell lung cancer histopathology images using deep learning

Your validation loss plot looks really strange; I've never seen anything that oscillates like that before. The oscillation seems both huge and regular, which is very puzzling. Are you shuffling your training data, or going through it in the same order each epoch? If you're not shuffling it, it's almost as if some subset of your training data is messing up the model. I would pull on this thread to identify which subset of the training data is causing the issue. Perhaps you can do this in a quick way by running all your training data through your best model and plotting the images / labels with the biggest losses (a sketch of this follows below). Or you can do something more computationally intensive, where you assign a score to each training image based on whether training on it (including it in a training batch) helps or hurts the subsequent validation run. That will help you single out the most 'harmful' training data points.
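
A minimal sketch of that quick check, assuming a trained model and a non-shuffled train_loader (both names are placeholders):

import torch
import torch.nn.functional as F

model.eval()
per_sample_losses = []
with torch.no_grad():
    for images, labels in train_loader:  # shuffle=False so indices line up
        logits = model(images)
        # reduction='none' keeps one loss value per image
        losses = F.cross_entropy(logits, labels, reduction='none')
        per_sample_losses.extend(losses.tolist())

# Indices of the 20 highest-loss training images: inspect these by hand.
worst = sorted(range(len(per_sample_losses)),
               key=lambda i: per_sample_losses[i], reverse=True)[:20]
print(worst)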

For your loss, it sounds like maybe you're using some kind of image-segmentation-style loss function. I think this is an atypical approach for image classification. Normally you want one label per image, either 0 or 1 depending on the class (positive or negative). So your net should output, for each image, a Tensor of size 2, where each element is a score (logit) for that class; note that PyTorch's CrossEntropyLoss applies log-softmax internally, so you should feed it raw logits rather than adding a Softmax layer at the end of your classifier. And your target for each image, with cross-entropy loss, is the true class index for that image (either 0 or 1). This might in principle work out to the same thing with your gigapixel approach (since there's only a binary piece of information in your label, even though it's spread out over the 512x512 patches). However, since it's atypical, I'd recommend trying the standard approach and seeing if that helps; a minimal sketch is below.
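
For concreteness, a sketch of that standard setup on torchvision's Inception V3 (images and labels stand in for one batch from your DataLoader; the 0.4 auxiliary-loss weight follows the PyTorch fine-tuning tutorial):

import torch.nn as nn
from torchvision import models

# Put two-class heads on a pre-trained Inception V3.
model = models.inception_v3(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 2)
model.AuxLogits.fc = nn.Linear(model.AuxLogits.fc.in_features, 2)

criterion = nn.CrossEntropyLoss()  # expects raw logits and class-index targets

model.train()
outputs, aux_outputs = model(images)  # train mode also returns aux logits
loss = criterion(outputs, labels) + 0.4 * criterion(aux_outputs, labels)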