Dealing with imbalanced datasets in PyTorch

Thank you for helping me out :smile:

This is how I update my model weights; each target and each y_pred is a binary vector.

# Compute the loss at each level and backpropagate the combined loss
# (predictions and targets must live on the same device)
loss1 = criterion(y_pred1.cuda(), target1.cuda())
loss2 = criterion(y_pred2.cuda(), target2.cuda())
loss3 = criterion(y_pred3.cuda(), target3.cuda())
loss = loss1 + loss2 + loss3  # combine all losses
loss.backward()
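For reference, here is a self-contained sketch of the same three-level pattern. Random tensors stand in for the model outputs and targets, and the batch size and level sizes are made up for illustration. It runs on the GPU when one is available and falls back to the CPU otherwise.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Run on the GPU when one is available, otherwise on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

criterion = nn.BCELoss()

# Made-up sizes: a batch of 4 samples, with 3, 5 and 7 binary labels at the three levels
batch_size, level_sizes = 4, [3, 5, 7]

# Stand-ins for model outputs: leaf logits passed through a sigmoid, so values lie in (0, 1)
logits = [torch.randn(batch_size, n, device=device, requires_grad=True) for n in level_sizes]
y_preds = [torch.sigmoid(z) for z in logits]
# Random 0/1 targets as float tensors, on the same device as the predictions
targets = [torch.randint(0, 2, (batch_size, n), device=device).float() for n in level_sizes]

# One loss per level, summed into a single scalar; a single backward() call
# propagates gradients through all three levels
losses = [criterion(p, t) for p, t in zip(y_preds, targets)]
loss = sum(losses)
loss.backward()
```

Summing gives every level equal influence. In a real training loop, optimizer.zero_grad() would come before the forward pass and optimizer.step() after backward() to actually update the parameters.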

My loss is criterion = nn.BCELoss(). My idea was to instead initialize three separate loss functions (criterion1, criterion2 and criterion3) so that I could pass each level its own weight vector directly. Something like:

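# weights1..weights3: assumed to be float tensors with one entry per label at each level
criterion1 = nn.BCELoss(weight=weights1)
criterion2 = nn.BCELoss(weight=weights2)
criterion3 = nn.BCELoss(weight=weights3)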
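A sketch of how the per-level weights could be built, assuming each weight vector has one entry per label and is derived from label frequencies in the training set. The helpers inverse_frequency_weights and positive_class_weights, the level sizes and the random data are all hypothetical. Option 1 follows the idea above: BCELoss broadcasts a (num_labels,) weight over the batch, but it scales the positive and negative terms of a label together, so it rebalances labels against each other rather than positives against negatives. Option 2 swaps in a different technique, nn.BCEWithLogitsLoss with pos_weight, which scales only the positive term (the PyTorch docs suggest num_negatives / num_positives per label) and expects raw logits instead of sigmoid outputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def inverse_frequency_weights(targets: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: one weight per label, larger for rarer labels.

    targets is a (num_samples, num_labels) 0/1 matrix for one level.
    The weights are rescaled to average 1 so the loss magnitude stays comparable.
    """
    pos_counts = targets.sum(dim=0).clamp(min=1)  # clamp avoids dividing by zero
    w = 1.0 / pos_counts
    return w * w.numel() / w.sum()

def positive_class_weights(targets: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: num_negatives / num_positives per label, for pos_weight."""
    pos = targets.sum(dim=0)
    neg = targets.shape[0] - pos
    return neg / pos.clamp(min=1)

# Made-up training labels for three levels with 3, 5 and 7 labels;
# positive rates range from 5% to 50% to mimic imbalance
level_sizes = [3, 5, 7]
train_targets = [
    (torch.rand(1000, n) < torch.linspace(0.05, 0.5, n)).float() for n in level_sizes
]

# Option 1: one BCELoss per level, each with its own per-label weight vector
criteria = [nn.BCELoss(weight=inverse_frequency_weights(t)) for t in train_targets]

# Option 2 (swapped-in alternative): BCEWithLogitsLoss with pos_weight, fed raw logits
logit_criteria = [
    nn.BCEWithLogitsLoss(pos_weight=positive_class_weights(t)) for t in train_targets
]

# One fake batch of 8 samples per level
logits = [torch.randn(8, n, requires_grad=True) for n in level_sizes]
batch_targets = [t[:8] for t in train_targets]

# BCELoss needs probabilities, so apply the sigmoid first
loss = sum(c(torch.sigmoid(z), y) for c, z, y in zip(criteria, logits, batch_targets))
# BCEWithLogitsLoss applies the sigmoid internally
logit_loss = sum(c(z, y) for c, z, y in zip(logit_criteria, logits, batch_targets))
loss.backward()
```

Both weight and pos_weight are registered as buffers on the loss module, so calling .cuda() or .to(device) on each criterion moves its weights to the GPU along with it.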