Using WeightedRandomSampler with ConcatDataset

For weighted sampling you would have to create a weight for each sample, not just for each class.
If you don’t have the targets already computed, you could iterate your dataset once and store them.
Here is a small example, which should match your use case:

```python
import torch
from torch.utils.data import (ConcatDataset, DataLoader, TensorDataset,
                              WeightedRandomSampler)

# Create dummy data with class imbalance 99 to 1
numDataPoints = 1000
data_dim = 5
bs = 100
data = torch.randn(numDataPoints, data_dim)
target = torch.cat((torch.zeros(int(numDataPoints * 0.99), dtype=torch.long),
                    torch.ones(int(numDataPoints * 0.01), dtype=torch.long)))

print('target train 0/1: {}/{}'.format(
    (target == 0).sum(), (target == 1).sum()))

# Create ConcatDataset
dataset = TensorDataset(data, target)
train_dataset = ConcatDataset((dataset, dataset))

# Get all targets by iterating the dataset once
targets = []
for _, t in train_dataset:
    targets.append(t)
targets = torch.stack(targets)

# Compute samples weight (each sample should get its own weight)
class_sample_count = torch.tensor(
    [(targets == t).sum() for t in torch.unique(targets, sorted=True)])
weight = 1. / class_sample_count.float()
samples_weight = torch.tensor([weight[t] for t in targets])

# Create sampler and loader
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

train_loader = DataLoader(
    train_dataset, batch_size=bs, num_workers=1, sampler=sampler)

# Iterate DataLoader and check class balance for each batch
for i, (x, y) in enumerate(train_loader):
    print("batch index {}, 0/1: {}/{}".format(
        i, (y == 0).sum(), (y == 1).sum()))
```

In the first part I’m creating a dummy imbalanced dataset.
You should of course skip this step and use your original ConcatDataset.
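Iterating train_dataset as above loads every sample, including the inputs, just to read its target, which can be slow for large datasets. If every dataset inside your ConcatDataset is a TensorDataset whose second tensor holds the targets (an assumption that matches the dummy data but not necessarily your datasets), you could concatenate the stored target tensors directly. ConcatDataset keeps its sub-datasets in its `datasets` attribute and TensorDataset keeps its tensors in `tensors`; the helper name `concat_targets` is hypothetical.

```python
import torch
from torch.utils.data import ConcatDataset, TensorDataset


def concat_targets(concat_dataset):
    # Hypothetical helper: assumes every sub-dataset is a TensorDataset
    # with the targets stored as its second tensor.
    parts = []
    for ds in concat_dataset.datasets:
        if not isinstance(ds, TensorDataset):
            raise TypeError("expected TensorDataset, got {}".format(type(ds).__name__))
        parts.append(ds.tensors[1])
    # ConcatDataset indexes its sub-datasets in order, so concatenating
    # their targets in the same order keeps indices aligned with the sampler.
    return torch.cat(parts)


data = torch.randn(1000, 5)
target = torch.cat((torch.zeros(990, dtype=torch.long),
                    torch.ones(10, dtype=torch.long)))
dataset = TensorDataset(data, target)
train_dataset = ConcatDataset((dataset, dataset))

targets = concat_targets(train_dataset)
print(targets.shape)  # torch.Size([2000])
```

For other dataset types you would need a different way to read the targets (or fall back to iterating once), so treat this as a shortcut for the TensorDataset case only.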

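After storing all targets, the class_sample_count and the corresponding samples_weight tensor are created and used to create the WeightedRandomSampler. Each sample gets the inverse frequency of its class as its weight, so samples of the rare class are drawn more often.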
As you can see in the last loop, each batch should now be roughly balanced, since the sampler draws both classes with about equal probability.
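Note that WeightedRandomSampler draws with replacement by default (replacement=True), so the balance is approximate: the 20 minority samples are repeated many times per epoch, while many majority samples are never drawn. A quick way to check this, assuming the targets and samples_weight from the example above (recreated here so the snippet runs on its own), is to iterate the sampler directly and count the drawn indices. The seeded generator is only there to make the run reproducible.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Recreate the targets and per-sample weights from the example above
targets = torch.cat((torch.zeros(1980, dtype=torch.long),
                     torch.ones(20, dtype=torch.long)))
weight = 1. / torch.bincount(targets).float()
samples_weight = weight[targets]

# Seeded generator only for reproducibility
g = torch.Generator().manual_seed(0)
sampler = WeightedRandomSampler(samples_weight, len(samples_weight),
                                replacement=True, generator=g)

# Indices drawn in one "epoch" and the classes they belong to
indices = torch.tensor(list(sampler))
drawn = targets[indices]
print("class 0/1 drawn: {}/{}".format(
    int((drawn == 0).sum()), int((drawn == 1).sum())))

# How many distinct samples of each class were actually used
unique_idx = torch.unique(indices)
print("distinct class 0/1 samples: {}/{}".format(
    int((targets[unique_idx] == 0).sum()), int((targets[unique_idx] == 1).sum())))
```

Switching to replacement=False with num_samples equal to the dataset length would not help here: every sample would then be drawn exactly once, so the epoch as a whole would keep the original 99 to 1 imbalance.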

Let me know if that works for you.
