How to create a DistributedSampler based on my own Sampler

I have implemented a Sampler for single-GPU training. I wrote it because I want all samples in a batch to come from the same source.

import numpy as np
import torch
from torch.utils.data import Dataset


class MultiSourceBatchSampler(torch.utils.data.Sampler):
    def __init__(
        self,
        data_source: Dataset,
        # sources: list,
        batch_size: int,
        shuffle: bool = False,
        drop_last: bool = False,
    ):
        self.data_source = data_source
        # self.sources = sources
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        self._length = -1
    
    def __iter__(self):
        # indices are tuples (source, second_key)
        indices = self.data_source._get_indices()
        if self.shuffle:
            np.random.shuffle(indices)
        # a stable sort groups indices by source while keeping the shuffled order within each source
        indices.sort(key=lambda x: x[0])

        batched_indices = []
        buffer = []
        for index in indices:
            # start a new batch when it is full or when the source changes
            if (len(buffer) == self.batch_size) or (len(buffer) > 0 and buffer[0][0] != index[0]):
                batched_indices.append(buffer)
                buffer = []
            buffer.append(index)
        if len(buffer) > 0 and not self.drop_last:
            batched_indices.append(buffer)
       
        if self.shuffle:
            np.random.shuffle(batched_indices)
        yield from batched_indices

    def __len__(self):
        if self._length >= 0:
            return self._length
        total = 0
        indices = self.data_source._get_indices()
        indices.sort(key=lambda x: x[0]) #sort by source

        buffer = []
        for index in indices:
            if (len(buffer) == self.batch_size) or (len(buffer) > 0 and buffer[0][0] != index[0]):
                total += 1
                buffer = []
            buffer.append(index)
        if len(buffer) > 0 and not self.drop_last:
            total += 1
        self._length = total
        return total
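
For reference, the sampler assumes the dataset exposes a _get_indices() method returning (source, second_key) tuples and accepts such tuples in __getitem__, roughly like this (an illustrative placeholder, not my real dataset):

class MyMultiSourceDataset(Dataset):  # placeholder name
    def __init__(self, sources: dict):
        # e.g. {"source_a": [sample, ...], "source_b": [sample, ...]}
        self.sources = sources

    def _get_indices(self):
        return [(src, i) for src, samples in self.sources.items() for i in range(len(samples))]

    def __getitem__(self, index):
        src, i = index
        return self.sources[src][i]

    def __len__(self):
        return sum(len(samples) for samples in self.sources.values())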

I did not find very detailed documentation on DistributedSampler, so I wonder what the principle is if I want to convert my sampler into a DistributedSampler.

DistributedSampler splits the dataset indices across the replicas, making sure each rank receives the same number of samples. The actual self.total_size depends on the number of samples in the dataset, the number of replicas, and whether the last samples should be dropped or repeated, as seen here. It also holds the seed and the current epoch (set via set_epoch), which allows you to restore the epoch and continue training with the same deterministic shuffling.
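
In simplified form, that logic looks roughly like this (a sketch of the idea, not the exact source):

import math

import torch


def distributed_indices(dataset_len, num_replicas, rank, epoch, seed=0, shuffle=True, drop_last=False):
    # Simplified sketch of what DistributedSampler does internally.
    if shuffle:
        g = torch.Generator()
        g.manual_seed(seed + epoch)  # deterministic per (seed, epoch)
        indices = torch.randperm(dataset_len, generator=g).tolist()
    else:
        indices = list(range(dataset_len))

    # total_size is a multiple of num_replicas, chosen by either dropping or repeating samples
    if drop_last:
        num_samples = dataset_len // num_replicas
    else:
        num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas

    if drop_last:
        indices = indices[:total_size]  # drop the tail
    else:
        padding = total_size - len(indices)
        indices += (indices * math.ceil(padding / len(indices)))[:padding]  # repeat samples

    return indices[rank:total_size:num_replicas]  # each rank takes a disjoint stride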

A possible general solution:

import math

import torch


class DistributedSamplerWrapper(torch.utils.data.distributed.DistributedSampler):
    def __init__(
        self,
        base_sampler,
        num_replicas=None,
        rank=None,
        seed=0,
    ):
        self.base_sampler = base_sampler
        self.batch_size = base_sampler.batch_size
        shuffle = self.base_sampler.shuffle
        drop_last = self.base_sampler.drop_last
        # DistributedSampler only calls len() on its first argument, so passing the base
        # sampler here partitions its batches (not individual samples) across the ranks.
        super().__init__(base_sampler, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last)

    def __iter__(self):
        # materialize the batches produced by the base sampler
        base_indices = list(self.base_sampler.__iter__())
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]
        # self.dataset is the base sampler here, so this permutation reorders whole batches
        indices = [base_indices[i] for i in indices]
            
        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[
                    :padding_size
                ]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

The wrapper should be passed as the batch_sampler argument of the DataLoader, since it yields lists of indices (one batch per element) rather than single sample indices.
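
Something like this for wiring it up (a rough sketch; the dataset, batch size, and epoch count are placeholders):

dataset = ...  # your dataset exposing _get_indices(), as in the question
base_sampler = MultiSourceBatchSampler(dataset, batch_size=32, shuffle=True)
batch_sampler = DistributedSamplerWrapper(base_sampler)
loader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler)

num_epochs = 10
for epoch in range(num_epochs):
    batch_sampler.set_epoch(epoch)  # keep the shuffle deterministic but different each epoch
    for batch in loader:
        ...

Each rank builds its own DataLoader with this wrapper, and set_epoch should be called before every epoch, just like with the plain DistributedSampler.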