The weight tensor will contain the reciprocal of the class counts.
So let’s say your class distribution is:
import torch

class0 = 100
class1 = 10
class2 = 1000
class_counts = [class0, class1, class2]
weight = 1. / torch.tensor(class_counts).float()
print(weight)
> tensor([0.0100, 0.1000, 0.0010])
As you can see, class1, which has the fewest samples, now has the highest weight.
However, the WeightedRandomSampler expects a weight for each sample, not for each class. So if your target is defined as:
target = torch.cat((
    torch.zeros(class0), torch.ones(class1), torch.ones(class2) * 2.
)).long()
# shuffle
target = target[torch.randperm(len(target))]
we can directly index weight
to get the corresponding weight for each target sample:
# Get corresponding weight for each target
sample_weight = weight[target]
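Putting the pieces together, sample_weight can then be passed to the WeightedRandomSampler and used in a DataLoader. Here is a minimal sketch; the random TensorDataset, batch size, and num_samples are just placeholder choices for illustration:

```python
import torch
from torch.utils.data import WeightedRandomSampler, DataLoader, TensorDataset

# Recreate the example from above
class_counts = [100, 10, 1000]
weight = 1. / torch.tensor(class_counts).float()
target = torch.cat((
    torch.zeros(100), torch.ones(10), torch.ones(1000) * 2.
)).long()
target = target[torch.randperm(len(target))]
sample_weight = weight[target]

# Draw num_samples indices with replacement, each index picked
# proportionally to its sample weight
sampler = WeightedRandomSampler(
    sample_weight, num_samples=len(sample_weight), replacement=True)

# Placeholder dataset: random features paired with the targets
dataset = TensorDataset(torch.randn(len(target), 3), target)
loader = DataLoader(dataset, batch_size=64, sampler=sampler)

# Each epoch should now yield roughly balanced class counts
drawn = torch.cat([labels for _, labels in loader])
print(torch.bincount(drawn, minlength=3))
```

Note that replacement=True is required here: class1 has only 10 samples, so oversampling it to balance the classes is only possible by drawing the same samples multiple times.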