Proper way of using WeightedRandomSampler()

I am working with an unbalanced dataset and I'm trying to use PyTorch's wonderful tools to deal with this all-too-common problem. I am having trouble using the WeightedRandomSampler, so I will write out my approach to make it easier to give feedback.

# Necessary imports
import torch
import torchvision
from torchvision import transforms

transform = transforms.Compose([
    transforms.RandomResizedCrop(
        size=224, scale=(0.4, 1.6), ratio=(0.9, 1.1)),
    transforms.ColorJitter(brightness=0.1, saturation=0.1,
                           contrast=0.1, hue=0.3),
    transforms.RandomVerticalFlip(),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])

weights = 1 / (torch.FloatTensor([970, 3308, 2407, 212, 4422, 11424, 286, 594, 272])+1e-5)
sampler = torch.utils.data.WeightedRandomSampler(weights=weights, num_samples=len(weights), replacement=True)
# Load the dataset
train_dir = base_dir + "/train"

train_dataset = torchvision.datasets.ImageFolder(train_dir, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=hparams['batch_size'],
                                           drop_last=True, num_workers=nw, sampler=sampler)

Now if I run:

>>> len(train_loader)
0

If I don't use the sampler, the result is not 0 but the number of images divided by the batch size (rounded down, since drop_last=True), which would be the correct result.
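The zero most likely comes from the sampler arguments rather than the dataset: with a sampler, len(DataLoader) is len(sampler) // batch_size when drop_last=True, and len(sampler) is num_samples, which here is len(weights) = 9 (one entry per class). A minimal sketch with a made-up toy dataset, assuming the real hparams['batch_size'] (not shown in the post) is larger than 9:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy stand-in for the real dataset: 100 samples, 9 classes (sizes made up for illustration).
dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 9, (100,)))

# Nine *class* weights, as in the question: the sampler therefore yields only 9 indices.
class_weights = torch.rand(9)
sampler = WeightedRandomSampler(class_weights, num_samples=len(class_weights), replacement=True)

batch_size = 32  # hypothetical; any batch size > 9 gives the same result
loader = DataLoader(dataset, batch_size=batch_size, drop_last=True, sampler=sampler)

# 9 // 32 == 0 full batches, and drop_last discards the incomplete one.
print(len(sampler), len(loader))  # 9 0

# Without a sampler: len(dataset) // batch_size == 100 // 32 == 3.
plain = DataLoader(dataset, batch_size=batch_size, drop_last=True)
print(len(plain))  # 3

# With one weight per *sample*, the length matches the plain loader again.
per_sample = torch.ones(len(dataset))
fixed = DataLoader(dataset, batch_size=batch_size, drop_last=True,
                   sampler=WeightedRandomSampler(per_sample, num_samples=len(dataset)))
print(len(fixed))  # 3
```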

Do you see anything that I'm doing wrong? Also, I don't fully understand the use of the “replacement” argument; if you could spare a moment to explain it, I'd be very grateful.

The weight tensor should contain a weight value for each sample, while yours seems to contain the class weights.
Have a look at this example.

With replacement=True, each sample can be picked again in every draw. The number of drawn samples is defined by the num_samples argument.
On the other hand, replacement=False will still use the sample weights to draw the samples; however, already picked samples will be removed from the set.
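A small sketch of the difference, using four made-up weights and a seeded generator so the draws are repeatable. With replacement=False each index can appear at most once, so num_samples cannot exceed len(weights) (the underlying torch.multinomial call raises an error in that case).

```python
import torch
from torch.utils.data import WeightedRandomSampler

weights = torch.tensor([0.1, 0.1, 0.1, 0.7])  # index 3 is 7x more likely than the others
g = torch.Generator().manual_seed(0)          # fixed seed so the draws are repeatable

# replacement=True: every draw uses the full pool, so indices (mostly 3) repeat,
# and num_samples may be larger than len(weights).
with_repl = list(WeightedRandomSampler(weights, num_samples=10, replacement=True, generator=g))

# replacement=False: a drawn index is removed from the pool, so each index appears at most once.
without_repl = list(WeightedRandomSampler(weights, num_samples=4, replacement=False, generator=g))

print(with_repl)
print(sorted(without_repl))  # [0, 1, 2, 3]
```

With replacement=False the weights only influence the *order* in which indices come out; drawing all of them, as above, returns every index exactly once.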


From the example you gave, I gather that we need a tensor whose length equals the size of our training dataset and whose values are the weights of each sample's class. For example, if I have the dataset [0,0,0,0,1,1], then the weight for each sample would be [0.25,0.25,0.25,0.25,0.5,0.5].
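That toy example can be checked directly: take the inverse class counts as class weights and index them with the labels to get one weight per sample.

```python
import torch

targets = torch.tensor([0, 0, 0, 0, 1, 1])   # class label of every sample
class_counts = torch.bincount(targets)        # tensor([4, 2])
class_weights = 1.0 / class_counts.float()    # tensor([0.2500, 0.5000])
sample_weights = class_weights[targets]       # look up each sample's class weight
print(sample_weights)  # tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.5000, 0.5000])
```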

Is there a simple way to get the weight for each sample, knowing the number of samples in each class? I could not reproduce your sample code for my 9-class example.

You would need to get the target tensor beforehand to be able to create the weights for each sample.
This can be done offline before training, e.g. by iterating the Dataset once.
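For an ImageFolder the labels are already stored in its targets attribute, so the images don't need to be loaded. Below is a sketch of a hypothetical helper, make_balanced_sampler, that turns a list of labels into a per-sample WeightedRandomSampler. Since the real folder isn't available here, the labels are synthesized from the nine class counts in the first post.

```python
import torch
from torch.utils.data import WeightedRandomSampler

def make_balanced_sampler(targets):
    """Hypothetical helper: WeightedRandomSampler with one inverse-frequency weight per sample."""
    targets = torch.as_tensor(targets)
    class_counts = torch.bincount(targets)
    # clamp(min=1) avoids a division by zero if some class index has no samples.
    class_weights = 1.0 / class_counts.clamp(min=1).float()
    sample_weights = class_weights[targets]  # same order as the dataset's samples
    return WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

# Synthetic labels reproducing the class counts from the first post.
counts = torch.tensor([970, 3308, 2407, 212, 4422, 11424, 286, 594, 272])
targets = torch.repeat_interleave(torch.arange(len(counts)), counts)

sampler = make_balanced_sampler(targets)
print(len(sampler))  # 23895: one draw per training sample per epoch
```

With the real data this would be sampler = make_balanced_sampler(train_dataset.targets), passed to the DataLoader exactly as before. Each class should then be drawn roughly equally often per epoch.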

What errors/issues are you seeing?

I managed to create the weights for each sample, but when I get into my training loop and try to get the data, the following error pops up:

Traceback (most recent call last):
  File "sampler_train.py", line 44, in <module>
    train_losses.append(training_functions.train_epoch(train_loader, network, optimizer, criterion, scheduler, hparams))
  File "/mnt/gpid07/imatge/carlos.hernandez/Documents/Code/TFM/ReusableFunctions/training_functions.py", line 9, in train_epoch
    for data, target in train_loader:
  File "/mnt/gpid07/imatge/carlos.hernandez/Documents/base/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "/mnt/gpid07/imatge/carlos.hernandez/Documents/base/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 838, in _next_data
    return self._process_data(data)
  File "/mnt/gpid07/imatge/carlos.hernandez/Documents/base/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 881, in _process_data
    data.reraise()
  File "/mnt/gpid07/imatge/carlos.hernandez/Documents/base/lib/python3.6/site-packages/torch/_utils.py", line 394, in reraise
    raise self.exc_type(msg)
IndexError: Caught IndexError in DataLoader worker process 5.
Original Traceback (most recent call last):
  File "/mnt/gpid07/imatge/carlos.hernandez/Documents/base/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/mnt/gpid07/imatge/carlos.hernandez/Documents/base/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/mnt/gpid07/imatge/carlos.hernandez/Documents/base/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/mnt/gpid07/imatge/carlos.hernandez/Documents/base/lib/python3.6/site-packages/torchvision/datasets/folder.py", line 137, in __getitem__
    path, target = self.samples[index]
IndexError: list index out of range

But it only happens when I use the “num_workers” argument of the DataLoader; if I don't use it, it does not crash.

What is the shape of your weights argument and what length are you passing to num_samples in the sampler?

The weights have the size of the train dataset, torch.Size([23903]), and num_samples is 23903.
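One thing worth ruling out: the class counts in the first post add up to 23,895, while the weights have 23,903 entries. If the weights were built from a different list of images than the one train_dataset indexes, the sampler can return indices past the end of train_dataset.samples, which would fit the IndexError raised in ImageFolder.__getitem__. This is only a guess based on the numbers in the thread. The sketch below uses a toy dataset and a hypothetical helper, out_of_range_indices, to show that the sampler draws from range(len(weights)) and knows nothing about the dataset's length.

```python
import torch
from torch.utils.data import TensorDataset, WeightedRandomSampler

def out_of_range_indices(sampler, dataset_len):
    """Hypothetical debugging helper: return every drawn index the dataset cannot serve."""
    return [i for i in sampler if i >= dataset_len]

dataset = TensorDataset(torch.arange(10))  # toy dataset with 10 samples
weights = torch.ones(12)                   # deliberately 2 weights too many

# The sampler draws indices from range(len(weights)), independent of len(dataset).
sampler = WeightedRandomSampler(weights, num_samples=1000, replacement=True)

bad = out_of_range_indices(sampler, len(dataset))
print(len(weights) == len(dataset))  # False: the mismatch to check first
print(sorted(set(bad)))              # almost surely [10, 11]
```

With the objects from the first post, the equivalent check would be assert len(weights) == len(train_dataset), placed right before the sampler is built.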

I looked at the example you linked once again and I think I am doing this right, but I have to admit that when I start getting into multi-threading (the Python version of multi-threading) I get lost.