Some problems with WeightedRandomSampler

The weight tensor should contain the reciprocal of the class counts, so that each class's weight is inversely proportional to its frequency.
So let's say your class distribution is:

class0 = 100
class1 = 10
class2 = 1000

import torch

class_counts = [class0, class1, class2]
weight = 1. / torch.tensor(class_counts).float()
print(weight)
> tensor([0.0100, 0.1000, 0.0010])

As you can see, class1, with the lowest number of samples, now has the highest weight. Note that each class's weights sum to the same total (count × 1/count = 1), so once these weights are assigned per sample, the sampler will draw from all classes with approximately equal probability.

However, WeightedRandomSampler expects a weight for each sample, not a weight per class.
So if your target is defined as:

target = torch.cat((
    torch.zeros(class0),      # 100 samples of class0
    torch.ones(class1),       # 10 samples of class1
    torch.ones(class2) * 2.   # 1000 samples of class2
)).long()
# shuffle
target = target[torch.randperm(len(target))]

we can directly index weight to get the corresponding weight for each target sample:

# Get corresponding weight for each target
sample_weight = weight[target]
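
To complete the picture, these per-sample weights can be passed to WeightedRandomSampler and then to a DataLoader. This is a minimal sketch; the dummy dataset and batch size are placeholders, not part of the original example:

```python
import torch
from torch.utils.data import WeightedRandomSampler, DataLoader, TensorDataset

# Class counts as in the example above
class_counts = [100, 10, 1000]
weight = 1. / torch.tensor(class_counts).float()

# Build and shuffle targets: 100 zeros, 10 ones, 1000 twos
target = torch.cat((
    torch.zeros(class_counts[0]),
    torch.ones(class_counts[1]),
    torch.ones(class_counts[2]) * 2.)).long()
target = target[torch.randperm(len(target))]

# Per-sample weights via direct indexing
sample_weight = weight[target]

# num_samples is commonly set to the dataset length;
# replacement=True lets minority samples be drawn multiple times per epoch
sampler = WeightedRandomSampler(
    weights=sample_weight,
    num_samples=len(sample_weight),
    replacement=True)

# Dummy feature tensor just for illustration
dataset = TensorDataset(torch.randn(len(target), 3), target)
loader = DataLoader(dataset, batch_size=64, sampler=sampler)

# Batches drawn from this loader should now be roughly class-balanced
data, labels = next(iter(loader))
```

Note that the sampler is mutually exclusive with `shuffle=True` in the DataLoader; the sampler already randomizes the order.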