I recently started working with PyTorch and went through the transfer learning tutorial. I now have it up and running, but the dataset I am working with is very imbalanced and relatively small (about 10,000 images). What steps should I take to adapt the transfer learning tutorial for an imbalanced dataset?
One approach would be to use a WeightedRandomSampler, or to use class weights in your loss function (e.g. nn.CrossEntropyLoss accepts a weight argument).
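Both options can be sketched as below. This is a minimal illustration on a hypothetical toy dataset (random tensors standing in for your images), not the tutorial's actual data pipeline:

```python
import torch
from torch import nn
from torch.utils.data import WeightedRandomSampler, DataLoader, TensorDataset

# Toy imbalanced dataset (hypothetical stand-in for the image dataset):
# 90 samples of class 0, 10 samples of class 1.
targets = torch.cat([torch.zeros(90, dtype=torch.long),
                     torch.ones(10, dtype=torch.long)])
data = torch.randn(100, 3)
dataset = TensorDataset(data, targets)

# Option 1: WeightedRandomSampler -- each sample is weighted by the
# inverse frequency of its class, so minority classes are drawn more often.
class_counts = torch.bincount(targets)                 # tensor([90, 10])
sample_weights = 1.0 / class_counts[targets].float()   # one weight per sample
sampler = WeightedRandomSampler(sample_weights,
                                num_samples=len(sample_weights),
                                replacement=True)
loader = DataLoader(dataset, batch_size=16, sampler=sampler)

# Option 2: keep ordinary sampling, but weight the loss per class instead.
class_weights = 1.0 / class_counts.float()
criterion = nn.CrossEntropyLoss(weight=class_weights)
```

Note that with Option 1 you would keep an unweighted criterion, since the sampler already rebalances the batches.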
For using weights in my loss function, what values would I use? Would it be the percentages of each class?
That would be one possibility.
I would recommend tracking a valid metric (accuracy might be misleading due to the class imbalance) and trying out different weightings.
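One common weighting scheme is the inverse class frequency rather than the raw percentages (percentages would up-weight the majority class, the opposite of what you want). A sketch with hypothetical class counts:

```python
import torch
from torch import nn

# Hypothetical counts for a 3-class imbalanced dataset of 10,000 images.
class_counts = torch.tensor([7000.0, 2500.0, 500.0])

# Raw percentages (0.70, 0.25, 0.05) would give the LARGEST weight to the
# majority class. Inverse frequency does the reverse; normalizing so the
# weights sum to the number of classes keeps the loss scale comparable.
weights = class_counts.sum() / class_counts
weights = weights / weights.sum() * len(class_counts)

criterion = nn.CrossEntropyLoss(weight=weights)
```

The normalization is optional; what matters is that rarer classes receive larger weights, and the exact scheme is worth tuning against your validation metric.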
Would I have to change my training function if I do pass in weights to the loss function?
What do you mean by training function?
The train_model() function given by the tutorial, the fitting function. Would I have to change that?
Ah, OK. No, you can just pass in your new (weighted) criterion and it should work.
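To make that concrete, here is a minimal sketch of a tutorial-style training step: the only change is constructing the criterion with a `weight` tensor, while the loop body itself is untouched. The model, optimizer, and data here are hypothetical stand-ins, not the tutorial's ResNet setup:

```python
import torch
from torch import nn

# Hypothetical stand-ins for the tutorial's model and data.
model = nn.Linear(10, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# The weighted criterion is the only new piece.
criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.5, 1.0, 3.0]))

inputs = torch.randn(8, 10)
labels = torch.randint(0, 3, (8,))

# Standard training step, exactly as in an unweighted setup.
optimizer.zero_grad()
loss = criterion(model(inputs), labels)  # same call signature as before
loss.backward()
optimizer.step()
```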
Which method is recommended / gives better metrics: using the weighted sampler in the DataLoader, or using weights in the cross-entropy loss function?
You should try both methods, as I cannot universally claim that one works better than the other.
Based on past posts in this forum, I've seen both claims.