Dealing with imbalanced datasets in PyTorch

Could you explain a bit more about how you use BCELoss with three different target levels?
Do you assign the targets like this?

0.0 - class0
0.5 - class1
1.0 - class2
If so, I would recommend using one output per class and weighting the loss of each prediction:

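```python
import torch
import torch.nn as nn

batch_size = 5
nb_classes = 3
output = torch.randn(batch_size, nb_classes)             # raw logits, one column per class
target = torch.empty(batch_size, nb_classes).random_(2)  # binary (0./1.) target per class
weight = torch.tensor([1.0, 2.0, 1.0])                   # per-class weights, here upweighting class1

# BCEWithLogitsLoss applies the sigmoid internally, so pass logits, not probabilities
criterion = nn.BCEWithLogitsLoss(reduction='none')
loss = criterion(output, target)  # unreduced loss, shape [batch_size, nb_classes]
loss = loss * weight              # weight broadcasts over the batch dimension
loss = loss.mean()
```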
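
As a minimal sketch (assuming the same shapes as above), the per-class weighting can also be passed through the `weight` argument of `nn.BCEWithLogitsLoss`. That tensor is multiplied into the unreduced loss and, in practice, broadcast over the batch dimension before the default `'mean'` reduction. The snippet below checks that both variants agree; the fixed seed is only there for reproducibility.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # reproducible random logits/targets

batch_size = 5
nb_classes = 3
output = torch.randn(batch_size, nb_classes)             # logits
target = torch.empty(batch_size, nb_classes).random_(2)  # binary targets
weight = torch.tensor([1.0, 2.0, 1.0])                   # per-class weights

# Manual variant: unreduced loss, scale each class column, then average
manual = nn.BCEWithLogitsLoss(reduction='none')(output, target)
manual = (manual * weight).mean()

# Built-in variant: `weight` is multiplied into the unreduced loss
# (broadcast over the batch dimension) before the default 'mean' reduction
builtin = nn.BCEWithLogitsLoss(weight=weight)(output, target)

print(torch.allclose(manual, builtin))  # True
```

`nn.BCEWithLogitsLoss` also accepts `pos_weight`, which scales only the positive-target term of each class. That is a different option from the one above and is also commonly used for imbalanced targets.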

Would that work for you?
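
If the labels are currently single class indices (class0, class1, class2 as in the question) rather than the 0.0/0.5/1.0 encoding, here is a minimal sketch of building the `[batch_size, nb_classes]` targets used above with `torch.nn.functional.one_hot`. The `labels` tensor is hypothetical example data. The inverse-frequency weights (`total / (nb_classes * count)`) are an assumption: one common heuristic, not necessarily how the `[1.0, 2.0, 1.0]` weights above were chosen.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

nb_classes = 3
# Hypothetical integer labels: class1 is under-represented
labels = torch.tensor([0, 0, 2, 1, 0, 2, 2, 0])

# One column per class instead of a single 0.0 / 0.5 / 1.0 target
target = F.one_hot(labels, num_classes=nb_classes).float()  # shape [8, 3]

# Assumed heuristic: inverse class frequency, normalized so a perfectly
# balanced dataset would give every class a weight of 1.0
# (a class with zero samples would get an infinite weight here)
counts = torch.bincount(labels, minlength=nb_classes).float()  # tensor([4., 1., 3.])
weight = counts.sum() / (nb_classes * counts)

output = torch.randn(labels.size(0), nb_classes)  # stand-in logits
loss = nn.BCEWithLogitsLoss(reduction='none')(output, target)
loss = (loss * weight).mean()
print(weight)
```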
