Uniform sampling in multi-label classification problem

Hi,

I work with a dataset of images, each with an associated one-hot vector of labels (code below); since an image can carry several labels, these are really multi-hot vectors. I would like to uniformly sample a label from the label set, which consists of 300 unique labels, and then randomly pick an image whose vector contains the selected label. E.g., in the first step I sample index=2, and in the second step I sample an image with the label vector [0,0,1,0,1,…]. If I understand the documentation of the built-in samplers correctly, they can only be used with map-style datasets. Should I write a custom DataLoader, or rather a custom sampler, for this kind of problem? Any help appreciated.

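import os

import torch
from PIL import Image
from torch.utils.data import Dataset


class datasetShop(Dataset):
    
    # df_csv    - pandas DataFrame describing the data, one row per image, with an
    #             'ID' column (file name without '.jpg') and a 'one_hot' column
    #             (sequence of 0/1 label flags)
    # root_dir  - path to the folder containing the images
    # transform - optional transform applied to each loaded image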
    
    def __init__(self, df_csv, root_dir, transform=None):
        
        self.df_csv = df_csv  
        self.root_dir = root_dir
        self.transform = transform
        
    def __len__(self):
        return len(self.df_csv)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        img_name = self.df_csv['ID'].iloc[idx] + '.jpg'
        img_name = os.path.join(self.root_dir, img_name)
        
        image = Image.open(img_name)
    
        one_hot = torch.tensor(self.df_csv['one_hot'].iloc[idx], dtype=torch.float32)
        
        if self.transform:
            image = self.transform(image)            
        
        sample = {'image':image, 'label':one_hot}

        return sample
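A minimal sketch of the two-step procedure from the question (draw a label uniformly, then draw uniformly among the images carrying it), written as a custom sampler rather than a custom DataLoader. `UniformLabelSampler` is a hypothetical name, not a PyTorch class, and it assumes each entry of `df_csv['one_hot']` is a sequence of 0/1 flags. It needs no torch import because `DataLoader` accepts any iterable of indices as its `sampler` argument; note also that `datasetShop` defines `__len__` and `__getitem__`, so it is already a map-style dataset.

```python
import random
from collections import defaultdict
from typing import Dict, Iterator, List, Optional, Sequence


class UniformLabelSampler:
    """Hypothetical sampler: yields dataset indices by first drawing a label
    uniformly at random, then drawing uniformly among the rows that have it."""

    def __init__(self, multi_hot: Sequence[Sequence[int]], num_samples: int,
                 seed: Optional[int] = None):
        # Map each label position to the list of row indices whose vector has a 1 there.
        self.label_to_indices: Dict[int, List[int]] = defaultdict(list)
        for row_idx, vec in enumerate(multi_hot):
            for label, flag in enumerate(vec):
                if flag:
                    self.label_to_indices[label].append(row_idx)
        # Only labels that occur in at least one row can be drawn.
        self.labels = sorted(self.label_to_indices)
        if not self.labels:
            raise ValueError("no positive labels found")
        self.num_samples = num_samples
        self.rng = random.Random(seed)

    def __iter__(self) -> Iterator[int]:
        for _ in range(self.num_samples):
            label = self.rng.choice(self.labels)                  # step 1: uniform over labels
            yield self.rng.choice(self.label_to_indices[label])   # step 2: uniform over rows with that label

    def __len__(self) -> int:
        return self.num_samples
```

Usage would look roughly like `DataLoader(datasetShop(df_csv, root_dir, transform), batch_size=32, sampler=UniformLabelSampler(df_csv['one_hot'].tolist(), num_samples=len(df_csv)))`. Leave `shuffle` at its default, since `DataLoader` rejects `shuffle=True` combined with a sampler. Draws are with replacement, so `num_samples` (the length of one "epoch") is a free choice. An image with several labels is reachable through each of them, so it can still be drawn more often than a single-label image.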

I am dealing with a similar but simpler issue.
I would like to sample by label, without one-hot encoding.
The labels are stored in some “df_csv” and served by “__getitem__()”.
How can we point the sampler to the correct column in df_csv, or otherwise sample according to the items’ labels?
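For single (non-one-hot) labels, one option that avoids a custom sampler class is to read the label column once and give each item a weight inversely proportional to its class size, then pass those weights to PyTorch's built-in `torch.utils.data.WeightedRandomSampler`. This class-balanced weighting is a swapped-in technique, not something proposed in this thread; `inverse_frequency_weights` is a hypothetical helper, and the column name `'label'` used below is an assumption.

```python
from collections import Counter
from typing import Hashable, List, Sequence


def inverse_frequency_weights(labels: Sequence[Hashable]) -> List[float]:
    """Hypothetical helper: one weight per item so that every class has the same total weight.

    Item i gets 1 / (number of items sharing labels[i]); summed over a class this
    is 1, so a weighted sampler picks each class with equal probability and,
    within a class, each item with equal probability.
    """
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]
```

For example, `weights = inverse_frequency_weights(df_csv['label'].tolist())` followed by `WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)`, passed as `sampler=` to the `DataLoader`. With replacement, each class is drawn with equal probability per sample, which matches the two-step uniform scheme from the original question for single-label data.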
@ptrblck

The sampler is responsible for passing the sample indices to Dataset.__getitem__, where you would load and process the data.
Inside __getitem__ you can use the index to access df_csv as needed, e.g. by indexing it with pandas (df_csv.iloc[idx], as the code in the question already does).
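A small sketch of that flow, assuming a DataFrame with hypothetical `'ID'` and `'label'` columns: `__getitem__` receives a positional index from the sampler and reads the row with `iloc`, and a helper picks the positional indices for one label. `FrameDataset` and `indices_with_label` are illustrative names; in real code the class would subclass `torch.utils.data.Dataset` as `datasetShop` does, and it is a plain class here only so the sketch runs without torch.

```python
from typing import Any, List, Tuple

import pandas as pd


class FrameDataset:
    """Hypothetical map-style dataset backed by a DataFrame."""

    def __init__(self, df: pd.DataFrame, id_col: str = "ID", label_col: str = "label"):
        self.df = df
        self.id_col = id_col
        self.label_col = label_col

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        # idx is whatever the sampler yielded: a positional row index, hence iloc.
        row = self.df.iloc[idx]
        # A real dataset would open the image named by row[self.id_col] here.
        return row[self.id_col], row[self.label_col]


def indices_with_label(df: pd.DataFrame, label_col: str, label: Any) -> List[int]:
    """Positional row indices whose label_col value equals label."""
    return [i for i, y in enumerate(df[label_col]) if y == label]


# Mimic what a DataLoader does: take indices from a sampler, then call __getitem__.
toy_df = pd.DataFrame({"ID": ["img0", "img1", "img2"], "label": [0, 1, 1]})
toy_dataset = FrameDataset(toy_df)
for i in indices_with_label(toy_df, "label", 1):
    print(toy_dataset[i])
```

The index list can also be handed to the built-in `torch.utils.data.SubsetRandomSampler`, e.g. `DataLoader(dataset, sampler=SubsetRandomSampler(indices_with_label(df_csv, 'label', 3)))`, to draw only items of one class in random order.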