Transfer learning with an imbalanced dataset


I recently started working with PyTorch and went through the transfer learning tutorial. I now have that up and running, but the dataset I am working with is very imbalanced and relatively small, at about 10,000 images. What steps should I take to adapt the transfer learning tutorial for an imbalanced dataset?

Thank you!

One approach would be to use a WeightedRandomSampler (here is a small example of its usage) or use weights in your loss function (e.g. nn.CrossEntropyLoss accepts a weight argument).
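A minimal sketch of both options, using a hypothetical 90/10 two-class split and random tensors in place of your image dataset (the class counts, dataset, and batch size here are made up for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Hypothetical imbalanced labels: 90 samples of class 0, 10 of class 1
targets = torch.cat([torch.zeros(90, dtype=torch.long),
                     torch.ones(10, dtype=torch.long)])
data = torch.randn(100, 4)  # stand-in for your actual images

# Weight each sample by the inverse frequency of its class
class_counts = torch.bincount(targets)       # tensor([90, 10])
class_weights = 1.0 / class_counts.float()   # rare class gets the larger weight
sample_weights = class_weights[targets]      # one weight per sample

# Option 1: oversample the minority class via the sampler
sampler = WeightedRandomSampler(sample_weights,
                                num_samples=len(sample_weights),
                                replacement=True)
loader = DataLoader(TensorDataset(data, targets),
                    batch_size=20, sampler=sampler)

# Option 2: keep the DataLoader as-is and weight the loss instead
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
```

Note that you would normally use one of the two, not both at once, since combining them double-counts the imbalance correction.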

If I use weights in my loss function, what values should I use? Would it be the percentage of each class?

That would be a possibility.
I would recommend tracking a suitable metric (plain accuracy might be misleading due to the imbalanced dataset) and trying out different weightings.
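To make both points concrete, here is a small sketch (the 90/10 split is a made-up example): a common convention is inverse class frequency, so the rare class gets the larger weight, and per-class recall shows why plain accuracy can mislead — a model that always predicts the majority class still scores 90% accuracy:

```python
import torch

# Hypothetical 90/10 two-class split
targets = torch.cat([torch.zeros(90, dtype=torch.long),
                     torch.ones(10, dtype=torch.long)])

# A degenerate "model" that always predicts the majority class
preds = torch.zeros(100, dtype=torch.long)

accuracy = (preds == targets).float().mean().item()  # 0.90 -- looks great

# Per-class recall exposes the problem: the rare class is never found
recalls = [(preds[targets == c] == c).float().mean().item() for c in (0, 1)]
balanced_accuracy = sum(recalls) / len(recalls)      # 0.50

# Inverse-frequency weights (not the raw class percentages, which would
# upweight the majority class instead):
counts = torch.bincount(targets).float()
weights = counts.sum() / (len(counts) * counts)      # tensor([0.5556, 5.0000])
criterion = torch.nn.CrossEntropyLoss(weight=weights)
```

The exact normalization of the weights is a design choice; treat them as a starting point and compare different weightings against your validation metric.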

Would I have to change my training function if I do pass in weights to the loss function?

What do you mean by training function?

The train_model() function given by the tutorial, the fitting function. Would I have to change that?

Ah, OK. No, you can just pass your new (weighted) criterion and it should work.

Which method is recommended, or gives better metrics: using the WeightedRandomSampler in the DataLoader, or using weights in the cross-entropy loss function?

You should try both methods, as I cannot universally claim that one works better than the other.
Based on past posts in this forum, I’ve seen both claims :wink: