How to use my own sampler when I already use DistributedSampler?

I want to use my own custom sampler (for example, I need oversampling and I want to use this repo: https://github.com/ufoym/imbalanced-dataset-sampler), but I already use DistributedSampler for the DataLoader because I train on multiple GPUs. How can I pass one more sampler to the DataLoader, or could I handle this through the Dataset instead? Currently I use a pretty simple ImageFolder dataset, and it would be great if I didn't need to rewrite it.

You can implement a wrapper class for your dataset and do the sampling there. For example, if you wanted to combine DistributedSampler with SubsetRandomSampler, you could implement a dataset wrapper like this:

import torch


class DistributedIndicesWrapper(torch.utils.data.Dataset):
    """
    Utility wrapper so that torch.utils.data.distributed.DistributedSampler
    can work with train/test splits defined by an index tensor.
    """
    def __init__(self, dataset: torch.utils.data.Dataset, indices: torch.Tensor):
        self.dataset = dataset
        self.indices = indices

    def __len__(self):
        # The wrapper only exposes the selected indices.
        return self.indices.size(0)

    def __getitem__(self, item):
        # TODO: do the custom sampling here?
        # Map the wrapper index to the underlying dataset index.
        idx = self.indices[item]
        return self.dataset[idx]

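For completeness, here is a minimal usage sketch of how such a wrapper might be combined with DistributedSampler. The dataset path, split ratio, and batch size are placeholder assumptions, and it assumes the distributed process group has already been initialized for multi-GPU training:

import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms

base_dataset = datasets.ImageFolder("path/to/train", transform=transforms.ToTensor())
# Hypothetical 80% "train" split expressed as an index tensor.
train_indices = torch.randperm(len(base_dataset))[: int(0.8 * len(base_dataset))]

wrapped = DistributedIndicesWrapper(base_dataset, train_indices)
# DistributedSampler shards the wrapped dataset across processes; this assumes
# torch.distributed.init_process_group(...) has already been called.
sampler = DistributedSampler(wrapped)
loader = DataLoader(wrapped, batch_size=32, sampler=sampler)
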
Thanks for the idea, danielhavir!

For everyone who is looking for an oversampling wrapper around a folder dataset (e.g. ImageFolder), you can take a look at this:

import random

import torch


class OversamplingWrapper(torch.utils.data.Dataset):
    def __init__(self, folder_dataset, oversampling_size=1000):
        self.folder_dataset = folder_dataset
        self.oversampling_size = oversampling_size
        self.num_classes = len(folder_dataset.classes)

        # Map each class index to the list of sample indices belonging to it.
        self.class_idx_to_sample_ids = {i: [] for i in range(self.num_classes)}
        for idx, (_, class_id) in enumerate(folder_dataset.samples):
            self.class_idx_to_sample_ids[class_id].append(idx)

    def __len__(self):
        # Every class contributes the same number of (virtual) samples.
        return self.num_classes * self.oversampling_size

    def __getitem__(self, index):
        # Cycle through the classes, then draw a random sample of that class.
        class_id = index % self.num_classes
        sample_idx = random.choice(self.class_idx_to_sample_ids[class_id])
        return self.folder_dataset[sample_idx]
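A rough usage sketch (the path, oversampling size, and batch size below are placeholders, and it again assumes the distributed process group is already initialized): wrap ImageFolder with the oversampler first, then let DistributedSampler shard the now class-balanced virtual dataset across GPUs:

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms

folder_dataset = datasets.ImageFolder("path/to/train", transform=transforms.ToTensor())
balanced_dataset = OversamplingWrapper(folder_dataset, oversampling_size=1000)

# Each process sees its own shard of the balanced dataset.
sampler = DistributedSampler(balanced_dataset, shuffle=True)
loader = DataLoader(balanced_dataset, batch_size=32, sampler=sampler)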