Sampling with replacement

I think it will still work. If I’m not misunderstanding your concern, that should be exactly how the WeightedRandomSampler works.
I’ve adapted an old example using two highly imbalanced classes:

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

numDataPoints = 1000
data_dim = 5
bs = 100

# Create dummy data with class imbalance 99 to 1
data = torch.randn(numDataPoints, data_dim)
target = torch.cat((torch.zeros(int(numDataPoints * 0.99), dtype=torch.long),
                    torch.ones(int(numDataPoints * 0.01), dtype=torch.long)))

print('target train 0/1: {}/{}'.format(
    (target == 0).sum().item(), (target == 1).sum().item()))

# Number of samples per class, indexed by class label
class_sample_count = torch.bincount(target)
# Weight each class inversely to its frequency
weight = 1. / class_sample_count.float()
# Look up the class weight for every individual sample
samples_weight = weight[target]

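# replacement=True is the default; it is spelled out here because the
# minority samples have to be drawn repeatedly to balance the batches
sampler = WeightedRandomSampler(
    samples_weight, num_samples=len(samples_weight), replacement=True)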

train_dataset = torch.utils.data.TensorDataset(data, target)

train_loader = DataLoader(
    train_dataset, batch_size=bs, num_workers=1, sampler=sampler)

for i, (x, y) in enumerate(train_loader):
    print("batch index {}, 0/1: {}/{}".format(
        i, (y == 0).sum().item(), (y == 1).sum().item()))

Even though the classes are distributed 99 to 1, each batch will contain an approximately balanced set of classes.
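To see why the inverse-frequency weights balance the draws, here is a rough standard-library sketch of the same idea. It uses random.choices as a stand-in for the sampler (an assumption for illustration, not how WeightedRandomSampler is implemented), and the names targets, mass and drawn are made up for this example:

```python
import random
from collections import Counter

random.seed(0)

# Same 99:1 imbalance as above: 990 samples of class 0, 10 of class 1.
targets = [0] * 990 + [1] * 10
class_count = Counter(targets)

# Per-sample weight = 1 / (size of that sample's class).
sample_weights = [1.0 / class_count[t] for t in targets]

# Total weight per class: 990 * (1/990) and 10 * (1/10), so both classes
# end up with the same share of the probability mass.
mass = {
    c: sum(w for w, t in zip(sample_weights, targets) if t == c)
    for c in class_count
}
print(mass)

# Draw 1000 indices with replacement, proportionally to the weights.
drawn = random.choices(range(len(targets)), weights=sample_weights, k=1000)
drawn_classes = Counter(targets[i] for i in drawn)
print(drawn_classes)

# Only 10 distinct class-1 samples exist, so they are necessarily repeated.
unique_minority = {i for i in drawn if targets[i] == 1}
print(len(unique_minority))
```

Both classes carry the same total weight, so roughly half of the draws come from class 1, but those draws are repeats of the same 10 samples. That is the trade-off of sampling with replacement: the batches are balanced, but the minority class adds no new data.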
Let me know if you can adapt this code to your use case.
