Alternative approach to unbalanced classes

If you are using a batch size of 1 (or, more generally, a batch in which every sample has the same target), the loss weighting won’t have any effect in the default setup using reduction='mean': the weighted sum of the per-sample losses is divided by the sum of the per-sample weights, so a common class weight cancels out.
Here is a small example showing this behavior:

# Setup
import torch
import torch.nn as nn

weight = torch.tensor([1., 10.])
criterion = nn.CrossEntropyLoss()
criterion_weighted = nn.CrossEntropyLoss(weight)

# All zeros
x = torch.randn(10, 2)  # logits for 10 samples and 2 classes
target = torch.zeros(10).long()
loss0_weighted = criterion_weighted(x, target)
loss0 = criterion(x, target)
print(loss0_weighted - loss0) # 0

# All ones
target = torch.ones(10).long()
loss1_weighted = criterion_weighted(x, target)
loss1 = criterion(x, target)
print(loss1_weighted - loss1) # 0

# Mixed
target = torch.randint(0, 2, (10,))
loss_mixed_weighted = criterion_weighted(x, target)
loss_mixed = criterion(x, target)
print(loss_mixed_weighted - loss_mixed) # difference

While the first two runs yield the same loss value with and without weighting, because all samples in the batch share the same target (so the common weight cancels in the normalization), the mixed batch shows a difference.
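
To make the cancellation explicit, the weighted mean can be recomputed manually. This is just a small check reusing x, target, and weight from the snippet above; per_sample and manual are illustrative names, not part of the original post:

per_sample = nn.CrossEntropyLoss(weight, reduction='none')(x, target)  # w_i * loss_i for each sample
manual = per_sample.sum() / weight[target].sum()                       # sum(w_i * loss_i) / sum(w_i)
print(torch.allclose(manual, loss_mixed_weighted))  # True
# If every sample in the batch has the same class, the common weight cancels
# in this ratio, which is why the first two runs matched the unweighted loss.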

You might want to use a WeightedRandomSampler instead, which draws more samples from the minority class so that each batch ends up approximately balanced.
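
Here is a minimal sketch of how it could be set up; the toy dataset, the 90/10 class split, and all variable names below are illustrative assumptions rather than part of the original post:

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy imbalanced dataset: 90 samples of class 0, 10 samples of class 1
data = torch.randn(100, 2)
targets = torch.cat([torch.zeros(90), torch.ones(10)]).long()
dataset = TensorDataset(data, targets)

# Inverse-frequency weight per class, then one weight per sample
class_counts = torch.bincount(targets)
class_weights = 1. / class_counts.float()
sample_weights = class_weights[targets]

sampler = WeightedRandomSampler(sample_weights,
                                num_samples=len(sample_weights),
                                replacement=True)
loader = DataLoader(dataset, batch_size=10, sampler=sampler)

for x_batch, y_batch in loader:
    print(y_batch)  # batches should now be roughly balanced between the classes

With replacement=True the minority samples are drawn repeatedly, so each epoch still iterates over num_samples items while the class frequencies inside each batch are roughly equal.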
