So my dataset is very unbalanced:
class-1 samples: 76
class-2 samples: 259
Now, before applying a weighted loss function for each class, I want to know: does it make sense to use class weights when the dataset is this unbalanced? Especially when my batch_size is 1
(I can't increase batch_size because of my system resources).
Are there any other approaches I can use in that case?
If you are using a batch size of 1, the loss weighting won't have any effect in the default setup with
reduction='mean', since the weighted mean is calculated: the loss is divided by the sum of the weights, so for a single sample (or any batch where all samples share the same target) the weight cancels out.
Here is a small example showing this behavior:
import torch
import torch.nn as nn

weight = torch.tensor([1., 10.])
criterion = nn.CrossEntropyLoss()
criterion_weighted = nn.CrossEntropyLoss(weight=weight)

# All samples belong to class 0
x = torch.randn(10, 2)
target = torch.zeros(10).long()
loss0_weighted = criterion_weighted(x, target)
loss0 = criterion(x, target)
print(loss0_weighted - loss0)  # 0

# All samples belong to class 1
target = torch.ones(10).long()
loss1_weighted = criterion_weighted(x, target)
loss1 = criterion(x, target)
print(loss1_weighted - loss1)  # 0

# Mixed targets
target = torch.randint(0, 2, (10,))
loss_mixed_weighted = criterion_weighted(x, target)
loss_mixed = criterion(x, target)
print(loss_mixed_weighted - loss_mixed)  # nonzero difference
While the first two runs yield the same loss value, as all samples in the batch have the same target, the last one shows a difference.
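To make the batch size 1 case explicit: with reduction='mean' the loss is sum_i(w_{y_i} * l_i) / sum_i(w_{y_i}), which for a single sample reduces to (w_y * l) / w_y = l, i.e. the weight cancels. A minimal sketch of this (same criteria as above):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

weight = torch.tensor([1., 10.])
criterion = nn.CrossEntropyLoss()
criterion_weighted = nn.CrossEntropyLoss(weight=weight)

# A single sample from the heavily weighted class 1
x = torch.randn(1, 2)
target = torch.ones(1).long()

# The weight appears in both the numerator and the denominator
# of the weighted mean, so it cancels for a batch of one sample
print(torch.allclose(criterion_weighted(x, target), criterion(x, target)))  # True
```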
You might want to use a
WeightedRandomSampler, which could draw more samples from the minority class. Here is an example of its usage.
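A minimal sketch of such a sampler, assuming a binary classification setup with the class counts from the question (76 vs. 259); the toy data and feature dimension are illustrative:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Toy dataset mirroring the class counts from the question
targets = torch.cat([torch.zeros(76), torch.ones(259)]).long()
data = torch.randn(len(targets), 8)
dataset = TensorDataset(data, targets)

# Weight each sample by the inverse frequency of its class
class_counts = torch.bincount(targets)                # tensor([ 76, 259])
sample_weights = 1.0 / class_counts[targets].float()  # one weight per sample

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),  # draw as many samples as the dataset holds
    replacement=True,                 # minority samples can be drawn repeatedly
)

loader = DataLoader(dataset, batch_size=1, sampler=sampler)

# Each class should now appear in roughly half of the draws
drawn = torch.cat([t for _, t in loader])
print(torch.bincount(drawn))  # approximately equal counts for both classes
```

Note that `replacement=True` is needed so the minority class can be oversampled; the per-sample (not per-class) weights are what `WeightedRandomSampler` expects.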