How to tackle the class imbalance problem during training in PyTorch

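You could oversample the minority classes (or undersample the majority ones) by passing a WeightedRandomSampler to your DataLoader, with each sample weighted by the inverse frequency of its class.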
Have a look at this post for a small example.
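
In case it helps, here is a minimal sketch of the sampler approach. The toy data (900 samples of class 0, 100 of class 1), the tensor shapes, and the batch size are made up for illustration. The inverse-frequency weighting is one common choice, not the only one.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Hypothetical imbalanced toy data: 900 samples of class 0, 100 of class 1.
targets = torch.cat([
    torch.zeros(900, dtype=torch.long),
    torch.ones(100, dtype=torch.long),
])
data = torch.randn(len(targets), 10)
dataset = TensorDataset(data, targets)

# Weight each class by the inverse of its frequency (an assumed choice).
class_counts = torch.bincount(targets)
class_weights = 1.0 / class_counts.float()

# The sampler needs one weight per sample, not per class.
sample_weights = class_weights[targets]

# replacement=True lets minority samples be drawn more than once per epoch.
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)

# Do not set shuffle=True here; it is mutually exclusive with sampler.
loader = DataLoader(dataset, batch_size=100, sampler=sampler)

# Count the labels drawn over one epoch to check the balance.
drawn = torch.cat([y for _, y in loader])
drawn_counts = torch.bincount(drawn, minlength=2)
print(drawn_counts)
```

With these weights both classes should appear roughly equally often in each epoch, although the exact counts vary from run to run. Note that the minority samples are repeated rather than new ones being created, so the model can still overfit to them.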