To counter overfitting on the majority class, you could use a weighted loss function (e.g. using the weight argument in nn.CrossEntropyLoss), or you could apply weighted sampling via WeightedRandomSampler, as given in this example, to balance the dataset.
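Here is a minimal sketch of both options. It assumes a hypothetical toy dataset (90 samples of class 0, 10 of class 1) and inverse class frequency as the weighting scheme. Other weightings are possible. The names `targets`, `features`, and `model` are illustrative only. In practice you would usually pick one of the two options rather than stack them, since combining both compensates for the imbalance twice.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Hypothetical imbalanced toy dataset: 90 samples of class 0, 10 of class 1.
torch.manual_seed(0)
targets = torch.cat([torch.zeros(90, dtype=torch.long), torch.ones(10, dtype=torch.long)])
features = torch.randn(len(targets), 4)
dataset = TensorDataset(features, targets)

# Per-class counts and inverse-frequency weights (an assumed weighting scheme).
class_counts = torch.bincount(targets)        # tensor([90, 10])
class_weights = 1.0 / class_counts.float()    # the rarer class gets the larger weight

# Option 1: weighted loss. `weight` holds one entry per class.
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Option 2: weighted sampling. Each sample gets the weight of its class,
# so both classes are drawn roughly equally often on average.
sample_weights = class_weights[targets]
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
loader = DataLoader(dataset, batch_size=20, sampler=sampler)

# Count how many samples of each class the sampler draws in one epoch.
drawn = torch.zeros(2, dtype=torch.long)
for _, y in loader:
    drawn += torch.bincount(y, minlength=2)
print("samples drawn per class:", drawn.tolist())

# Weighted loss on one batch from a hypothetical linear model.
model = nn.Linear(4, 2)
x, y = next(iter(loader))
loss = criterion(model(x), y)
print("weighted loss:", loss.item())
```

With the sampler, an epoch contains far more than 10 minority-class samples, because those samples are drawn with replacement. Note that `replacement=True` is what allows the minority samples to be repeated within an epoch.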