I recently started working with PyTorch and went through the transfer learning tutorial. I now have that up and running, but the dataset I am working with is very imbalanced and relatively small, at 10,000 images. What steps should I take to adapt the transfer learning tutorial to an imbalanced dataset?
One approach would be to use a WeightedRandomSampler (here is a small example of its usage) or use weights in your loss function (e.g. nn.CrossEntropyLoss accepts a weight argument).
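A minimal sketch of both options mentioned above, assuming inverse-frequency class weights (one common choice, not the only one). The toy `TensorDataset` with a 900/100 split, the feature size, and the batch size are made up for illustration; with the tutorial's `ImageFolder` dataset the labels should be available through its `targets` attribute instead.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Hypothetical stand-in for the real dataset: 900 samples of class 0, 100 of class 1.
targets = torch.cat([torch.zeros(900, dtype=torch.long),
                     torch.ones(100, dtype=torch.long)])
features = torch.randn(len(targets), 8)
dataset = TensorDataset(features, targets)

# Inverse-frequency weight per class (assumption: rarer class gets a larger weight).
class_counts = torch.bincount(targets)
class_weights = 1.0 / class_counts.float()

# Option 1: oversample the minority class while loading.
# Each sample gets the weight of its class; sampling is done with replacement.
sample_weights = class_weights[targets]
sampler = WeightedRandomSampler(weights=sample_weights,
                                num_samples=len(sample_weights),
                                replacement=True)
# Pass the sampler instead of shuffle=True; the two are mutually exclusive.
loader = DataLoader(dataset, batch_size=100, sampler=sampler)

# Option 2: keep the default sampling and weight the loss per class instead.
criterion = nn.CrossEntropyLoss(weight=class_weights / class_weights.sum())
```

With option 1 the batches drawn from `loader` should contain the two classes in roughly equal proportion; with option 2 the batches stay imbalanced, but mistakes on the rare class contribute more to the loss. You would normally pick one of the two rather than stacking both.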
That would be a possibility.
I would recommend tracking a suitable metric (accuracy might be misleading due to the imbalanced dataset) and trying out different weightings.
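To illustrate why accuracy can mislead here, the sketch below compares plain accuracy with per-class recall and balanced accuracy (the mean of the per-class recalls). The helper names and the toy labels are made up for this example, and balanced accuracy is just one possible metric; it is written in plain Python, so tensors of labels and predictions would need `.tolist()` first.

```python
from collections import Counter

def per_class_recall(y_true, y_pred):
    """Fraction of samples of each class that were predicted correctly."""
    total = Counter(y_true)
    correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return {c: correct[c] / total[c] for c in sorted(total)}

def balanced_accuracy(y_true, y_pred):
    """Mean of the per-class recalls, so every class counts equally."""
    recalls = per_class_recall(y_true, y_pred)
    return sum(recalls.values()) / len(recalls)

# Toy case: 9 samples of class 0, 1 sample of class 1,
# and a model that always predicts the majority class.
y_true = [0] * 9 + [1]
y_pred = [0] * 10

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(accuracy)                           # 0.9
print(per_class_recall(y_true, y_pred))   # {0: 1.0, 1: 0.0}
print(balanced_accuracy(y_true, y_pred))  # 0.5
```

The always-majority model reaches 90% accuracy while never detecting the minority class, which the per-class recall and the balanced accuracy make visible.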
You should try both methods, as I cannot universally claim one method works better than the other.
Based on past posts in this forum, I’ve seen both claims.