WeightedRandomSampler for custom image dataloader

I am trying to handle class imbalance by using WeightedRandomSampler with a custom dataset for multiclass image classification, but I can't figure out the best way to implement it. The images are in a folder and the labels are in a CSV file. The dataset code, without the weighted sampler, is below.

import os

from torch.utils.data import Dataset


class CassavaDataset(Dataset):
    def __init__(self, df, data_root, transforms=None, output_label=True):
        super().__init__()
        self.df = df.reset_index(drop=True).copy()  # one row per image: image_id, label
        self.transforms = transforms
        self.data_root = data_root
        self.output_label = output_label

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index: int):
        row = self.df.iloc[index]
        # get_img is the user's own image-loading helper (not shown here)
        img = get_img(os.path.join(self.data_root, row['image_id']))

        # transforms are called as transforms(image=...) and return a dict (albumentations style)
        if self.transforms:
            img = self.transforms(image=img)['image']

        # return the label only when requested (e.g. not at inference time)
        if self.output_label:
            return img, row['label']
        return img

What is the best way to compute a weight for each class and feed the weights to the sampler, before the augmentation is applied? Thanks in advance!
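A minimal sketch of one common approach, assuming the labels live in the DataFrame's 'label' column as in the dataset above: give each sample a weight of 1 / (number of samples in its class) and pass those weights to torch.utils.data.WeightedRandomSampler. The helper make_sample_weights and the toy DataFrame are hypothetical. The weights must be computed from the same DataFrame, in the same row order, that the dataset indexes with iloc. Because the sampler only produces indices, the weighting happens before __getitem__ runs the augmentation.

```python
import pandas as pd
import torch
from torch.utils.data import WeightedRandomSampler


def make_sample_weights(labels: pd.Series) -> torch.Tensor:
    """Hypothetical helper: per-sample weight = 1 / size of that sample's class."""
    class_counts = labels.value_counts()          # label -> number of samples
    per_sample = 1.0 / labels.map(class_counts)   # each row gets its class's inverse count
    return torch.tensor(per_sample.to_numpy(), dtype=torch.double)


# Hypothetical toy labels standing in for the CSV's 'label' column.
df = pd.DataFrame({
    "image_id": [f"img_{i}.jpg" for i in range(10)],
    "label": [0] * 7 + [1] * 2 + [2] * 1,
})
df = df.reset_index(drop=True)  # same row order as CassavaDataset uses

weights = make_sample_weights(df["label"])
# Draw one epoch's worth of indices, with replacement, so rare classes can repeat.
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

# The loader would then be built with sampler= and WITHOUT shuffle=True,
# since DataLoader does not accept both, e.g.:
# loader = DataLoader(CassavaDataset(df, data_root, transforms), batch_size=32, sampler=sampler)
```

Each class ends up with the same total weight (1.0 here), so every class is drawn with roughly equal probability while the epoch length stays at len(df).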

I know it is hacky, but I usually just enlarge the dataset so that indices which should have a higher weight are repeated.
Alternatively, you could "leak" the weights from the dataset to the sampler.

Best regards

Thomas

Hey Thomas, thanks for the reply, but I'm not sure I understand. Could you elaborate or give an example?

I don’t know much about your dataset, but here is a very simple two-class example:

Say class 1 occupies indices 0…9_999 and class 2 occupies indices 10_000…10_048 (49 samples). You could then add the following at the top of __getitem__:

index = index if index < 10_000 else 10_000 + (index - 10_000) % 49

and then make __len__ return 20_000. Indices 10_000…19_999 then cycle through the 49 class-2 samples, so class 2 is oversampled to the same weight as class 1.
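The same repeat-the-indices trick can be generalized to any number of classes. This is a minimal sketch, assuming you have the labels as a sequence in dataset order (e.g. df['label'] after reset_index); RepeatOversampledDataset is a hypothetical name, not a PyTorch class. It builds an index table in which every class fills as many slots as the largest class, cycling through each class's indices with a modulo, as in the two-class example.

```python
from collections import defaultdict

from torch.utils.data import Dataset


class RepeatOversampledDataset(Dataset):
    """Hypothetical wrapper: repeats indices of smaller classes so every class
    occupies as many slots as the largest one (deterministic oversampling)."""

    def __init__(self, base, labels):
        self.base = base
        by_class = defaultdict(list)          # label -> list of base indices
        for i, y in enumerate(labels):
            by_class[y].append(i)
        target = max(len(ix) for ix in by_class.values())
        self.index_map = []
        for y in sorted(by_class):
            ix = by_class[y]
            # cycle through this class's indices until it fills `target` slots
            self.index_map.extend(ix[k % len(ix)] for k in range(target))

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, index: int):
        # map the enlarged index back to a real sample, then defer to the base dataset
        return self.base[self.index_map[index]]


# Two-class layout from the example: class 0 at 0..9_999, class 1 at 10_000..10_048.
labels = [0] * 10_000 + [1] * 49
base = list(range(len(labels)))  # stand-in dataset: item i is just i
ds = RepeatOversampledDataset(base, labels)
```

Unlike WeightedRandomSampler, this oversampling is deterministic, so the wrapped dataset is usually combined with shuffle=True in the DataLoader. With k classes, the enlarged length is k times the size of the largest class.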

Best regards

Thomas