Using a BatchSampler in a multi-GPU scenario

Hello, I have a piece of code that uses a torch.utils.data.DataLoader with a custom BatchSampler to build batches containing the same number of samples from each class.
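For reference, here is a minimal sketch of the kind of class-balanced batch sampler I mean (the class and parameter names are only illustrative, not my exact code):

import random
from collections import defaultdict
from torch.utils.data import Sampler

class BalancedBatchSampler(Sampler):
    """Yields batches of dataset indices with the same number of samples per class."""

    def __init__(self, labels, classes_per_batch, samples_per_class):
        self.labels = labels
        self.classes_per_batch = classes_per_batch
        self.samples_per_class = samples_per_class
        # Group dataset indices by their class label.
        self.indices_by_class = defaultdict(list)
        for idx, label in enumerate(labels):
            self.indices_by_class[label].append(idx)

    def __len__(self):
        return len(self.labels) // (self.classes_per_batch * self.samples_per_class)

    def __iter__(self):
        classes = list(self.indices_by_class.keys())
        for _ in range(len(self)):
            batch = []
            for cls in random.sample(classes, self.classes_per_batch):
                # Sample with replacement so small classes never raise an error.
                batch.extend(random.choices(self.indices_by_class[cls], k=self.samples_per_class))
            yield batch

It is passed to the DataLoader through the batch_sampler argument rather than sampler, e.g. DataLoader(dataset, batch_sampler=BalancedBatchSampler(labels, 4, 8), num_workers=num_workers), which is why a plain DistributedSampler does not cover my case.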

I’m trying to use it in a multi-GPU scenario with the NeMo framework. By default, when running in multi-GPU mode, the data loader setup should be something like this:

if self._placement == DeviceType.AllGpu:
    sampler = torch.utils.data.distributed.DistributedSampler(self._dataset)

self._dataloader = torch.utils.data.DataLoader(
    dataset=self._dataset,
    sampler=sampler,
    num_workers=num_workers,
)

I’ve found some tricks for implementing a custom distributed sampler, but none of them work for a custom distributed batch sampler. What can I do?
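To make it concrete what I mean by a distributed batch sampler, the kind of wrapper I have in mind looks roughly like this (only a sketch; the class name and the slicing scheme are my own, not a NeMo or PyTorch API):

import torch.distributed as dist

class DistributedBatchSamplerWrapper:
    """Wraps a batch sampler and keeps only this rank's slice of every batch."""

    def __init__(self, batch_sampler, num_replicas=None, rank=None):
        self.batch_sampler = batch_sampler
        self.num_replicas = num_replicas if num_replicas is not None else dist.get_world_size()
        self.rank = rank if rank is not None else dist.get_rank()

    def __iter__(self):
        for batch in self.batch_sampler:
            # Each rank takes a disjoint, interleaved slice of the global batch.
            yield batch[self.rank::self.num_replicas]

    def __len__(self):
        return len(self.batch_sampler)

As far as I can tell, the interleaved slice only keeps the per-class balance when samples_per_class is a multiple of the number of replicas, and I don’t see how to plug something like this into the NeMo data layer setup shown above.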

@VitalyFedyunin for data loader questions