WeightedRandomSampler only sampling from one class

I’m running into an issue where WeightedRandomSampler is only sampling data from one class.

I’ve tried replacing it with shuffle=True, and that works fine, but I need weighted sampling because my data is imbalanced.
Here’s the relevant code:

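import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision.datasets import ImageFolder
from torchvision.io import read_image
from torchvision.transforms import RandAugment
# NormalizeImage is a custom module defined elsewhere (not shown)

batch_size = 64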

transform = torch.jit.script(nn.Sequential(
    RandAugment(num_ops=5, magnitude=10),
    NormalizeImage()
))

sample_count = [len(os.listdir(f"../../data/train_data/{i}")) for i in range(9)]
weights = torch.tensor(sample_count).reciprocal()

dataset = ImageFolder(root="../../data/train_data/", loader=read_image, transform=transform)
loader = DataLoader(dataset, batch_size=batch_size, sampler=WeightedRandomSampler(weights=weights, num_samples=batch_size), pin_memory=True)

data, labels = next(iter(loader))
data = data.cuda()
labels = labels.cuda()
print(labels)

The output of this is

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')

No good.

If it’s helpful, my sample_count is

[1025, 964, 931, 922, 962, 909, 974, 916, 8306] 

and my weights are

tensor([0.0010, 0.0010, 0.0011, 0.0011, 0.0010, 0.0011, 0.0010, 0.0011, 0.0001])

Also, as I mentioned before, if I change my DataLoader to use shuffle=True

loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

I get the expected outcome of

tensor([8, 3, 8, 8, 8, 8, 8, 2, 8, 8, 8, 2, 4, 8, 8, 8, 1, 8, 8, 3, 0, 8, 3, 2,
        1, 8, 8, 5, 8, 8, 8, 8, 8, 8, 7, 5, 8, 4, 3, 8, 8, 8, 8, 8, 3, 3, 4, 0,
        7, 8, 3, 2, 8, 8, 5, 5, 3, 8, 8, 8, 8, 4, 2, 5], device='cuda:0')

(For clarity, I have 9 classes numbered 0 through 8, each with its own subfolder in the train_data directory.)
Any help would be appreciated!

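WeightedRandomSampler expects one weight per sample in the dataset, not one weight per class. With only 9 weights, it can only return the indices 0 through 8, and because ImageFolder stores its samples class by class, those first 9 samples all belong to class 0. To fix this, give each sample the weight of its class.
Have a look at this post, which gives an example.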
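A minimal sketch of the per-sample fix, using only torch so it runs without the image folders. It assumes ImageFolder's targets list is ordered class by class and rebuilds that list from the sample counts in the question instead of reading images; the variable names (class_weights, sample_weights, bad_sampler, etc.) are illustrative, not from the original post.

```python
import torch
from torch.utils.data import WeightedRandomSampler

torch.manual_seed(0)  # only to make this demo reproducible

# Per-class image counts from the question (class 8 is the majority class).
sample_count = [1025, 964, 931, 922, 962, 909, 974, 916, 8306]

# Stand-in for ImageFolder's `targets`: samples are stored class by class,
# i.e. [0]*1025 + [1]*964 + ... + [8]*8306.
targets = torch.cat([
    torch.full((n,), cls, dtype=torch.long) for cls, n in enumerate(sample_count)
])

class_weights = 1.0 / torch.tensor(sample_count, dtype=torch.float)

# Original approach: 9 weights, so the sampler can only return indices 0..8,
# and those first 9 samples all belong to class 0.
bad_sampler = WeightedRandomSampler(class_weights, num_samples=64)
bad_indices = torch.tensor(list(bad_sampler))
print(targets[bad_indices].unique())  # tensor([0])

# Fix: give every sample the weight of its class (one weight per sample).
sample_weights = class_weights[targets]
sampler = WeightedRandomSampler(
    sample_weights, num_samples=len(sample_weights), replacement=True
)
indices = torch.tensor(list(sampler))

# Count how many draws landed in each class; expect roughly equal counts.
drawn = torch.bincount(targets[indices], minlength=len(sample_count))
print(drawn)

# Applied to the real dataset (not executed here), something like:
# targets = torch.tensor(dataset.targets)
# class_weights = 1.0 / torch.bincount(targets).float()
# sample_weights = class_weights[targets]
# sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights))
# loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, pin_memory=True)
```

Setting num_samples=len(sample_weights) is a choice made for this sketch: with num_samples=batch_size as in the question, the sampler yields only 64 indices, so each pass over the loader would produce a single batch.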
