RuntimeError with WeightedRandomSampler

Hi, I have implemented the following piece of code to oversample my highly imbalanced training dataset.

    if OVERSAMPLING:
        class_weights = 1. / torch.Tensor([25810, 2443, 5292, 873, 708]) # 1/number of samples in each class
        weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(weights = class_weights, num_samples=NUM_SAMPLES_TRN, replacement=True)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler = weighted_sampler)

However, I am getting errors and have a couple of questions about things I could not clearly understand from the docs or from other questions posted here:

  1. What should the value of the num_samples argument of WeightedRandomSampler be: the batch size or the size of the whole dataset?

  2. What exactly does replacement=True mean, and why does my code not run with replacement=False? I get:

    RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement

  3. With replacement=True, the code snippet above produces the following batches of labels:

    tensor([0, 0, 0, 0, 0, 0, 0, 0])
    tensor([0, 0, 0, 0, 0, 0, 0, 0])
    tensor([0, 0, 0, 0, 0, 0, 0, 0])
    tensor([0, 0, 0, 0, 0, 0, 0, 0])
    tensor([0, 0, 0, 0, 0, 0, 0, 0])
    tensor([0, 0, 0, 0, 0, 0, 0, 0])
    tensor([0, 0, 0, 0, 0, 0, 0, 0])
    tensor([0, 0, 0, 0, 0, 0, 0, 0])
    tensor([0, 0, 0, 0, 0, 0, 0, 0])
    tensor([0, 0, 0, 0, 0, 0, 0, 0])
    tensor([0, 0, 0, 0, 0, 0, 0, 0])

It's sampling only from class 0, which is the most populous. What I wanted instead was to favor sampling from the classes that don't have many samples.

Could you please give me a hand?

Thanks a lot!
Vojtooo

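The weights argument should contain a weight for each sample, not just the class weights. The sampler draws dataset indices from range(len(weights)), so with only 5 weights it can only ever return indices 0 to 4, which apparently all belong to class 0 in your dataset. That is why every batch contains only zeros.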
Have a look at this example.
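A minimal sketch of the difference, assuming the labels are available as a tensor. The toy labels (`class_counts_true`, `targets`) and dummy features are hypothetical, made up for illustration; in practice the labels would come from your own dataset. The sketch first shows that a sampler built from the 5 class weights only ever returns indices 0 to 4, then builds one weight per sample by indexing the class weights with each sample's label.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Hypothetical toy labels standing in for the real training targets:
# a 5-class problem where class 0 dominates (sorted, so indices 0..4 are class 0).
class_counts_true = [50, 10, 20, 5, 4]
targets = torch.cat([torch.full((n,), c, dtype=torch.long)
                     for c, n in enumerate(class_counts_true)])
features = torch.randn(len(targets), 3)  # hypothetical dummy features
train_dataset = TensorDataset(features, targets)

# What the original snippet does: only 5 weights, so the sampler can only
# ever yield dataset indices 0..4 (which here all belong to class 0).
class_counts = torch.bincount(targets)      # number of samples per class
class_weights = 1.0 / class_counts.float()  # one weight per class
wrong_sampler = WeightedRandomSampler(class_weights, num_samples=20, replacement=True)
wrong_indices = list(wrong_sampler)

# Fix: give every sample the weight of its class.
sample_weights = class_weights[targets]     # shape: (len(train_dataset),)
sampler = WeightedRandomSampler(sample_weights,
                                num_samples=len(sample_weights),
                                replacement=True)
train_loader = DataLoader(train_dataset, batch_size=8, sampler=sampler)

# Each class now carries the same total weight (n_c * 1/n_c = 1),
# so the classes are drawn with roughly equal probability.
per_class_weight = torch.zeros(len(class_counts)).index_add_(0, targets, sample_weights)
```

To inspect the effect, you could iterate over `train_loader` and print `torch.bincount(y, minlength=5)` for each batch `(x, y)`; the counts are random, but minority classes should now appear regularly.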

  1. num_samples is the total number of indices the sampler draws per epoch, not the batch size. Usually you would pass num_samples = len(dataset), but you could also use a different value if you want each epoch to contain more or fewer samples.

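  2. With replacement = True, each drawn sample can be drawn again. On the other hand, replacement = False still uses the weights associated with each sample, but removes a drawn sample from the set, so each index can appear at most once. Consequently, without replacement num_samples cannot exceed len(weights). With only your 5 class weights (and assuming NUM_SAMPLES_TRN is larger than 5), that is exactly what raises the RuntimeError.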
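Both points can be checked with a small sketch. The 100-sample dataset, the batch size of 16, and the five weights below are hypothetical values chosen for illustration. It shows that `num_samples` controls how many indices (and therefore batches) one epoch yields, and that without replacement asking for more draws than there are weights fails in the same way as in the question.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Hypothetical dataset: 100 samples with uniform per-sample weights.
dataset = TensorDataset(torch.arange(100))
sample_weights = torch.ones(len(dataset))
batch_size = 16

# 1. num_samples is the number of indices drawn per epoch, not the batch size.
epoch_sampler = WeightedRandomSampler(sample_weights, num_samples=len(dataset), replacement=True)
epoch_loader = DataLoader(dataset, batch_size=batch_size, sampler=epoch_sampler)
print(len(epoch_sampler), len(epoch_loader))  # 100 7  (ceil(100 / 16) batches)

# Setting num_samples to the batch size yields only a single batch per epoch.
tiny_sampler = WeightedRandomSampler(sample_weights, num_samples=batch_size, replacement=True)
tiny_loader = DataLoader(dataset, batch_size=batch_size, sampler=tiny_sampler)
print(len(tiny_sampler), len(tiny_loader))  # 16 1

# 2. replacement, with five hypothetical per-sample weights.
weights = torch.tensor([0.1, 0.2, 0.3, 0.2, 0.2])

# With replacement, indices may repeat, so num_samples may exceed len(weights).
with_repl = list(WeightedRandomSampler(weights, num_samples=10, replacement=True))

# Without replacement, each index appears at most once: here a permutation of 0..4.
no_repl = list(WeightedRandomSampler(weights, num_samples=5, replacement=False))

# More draws than weights without replacement reproduces the error from the question
# (the check happens when the sampler is iterated, inside torch.multinomial).
error_message = None
try:
    list(WeightedRandomSampler(weights, num_samples=10, replacement=False))
except RuntimeError as err:
    error_message = str(err)
print(error_message)
```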