Balanced batch sampling with DistributedSampler/DDP

Hello,

I am trying to use a balanced/PK batch sampler for triplet loss training, as in Hermans et al. 2017 (https://arxiv.org/pdf/1703.07737): each batch is built from P classes with K samples per class. However, I am training on multiple GPUs with DistributedDataParallel, and I have been unable to find an implementation that works with DistributedSampler or DDP in general. I have seen implementations like DistributedSamplerWrapper from Catalyst, but as far as I can tell it wraps plain per-index samplers, so I am unsure whether a balanced batch sampler (which yields whole batches of indices) would work properly with it. A rough sketch of what I am after is below.
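
For concreteness, this is roughly the kind of sampler I have in mind (an untested sketch, not a working implementation; `DistributedPKSampler`, `labels`, `p`, and `k` are my own names, and the rank-sharding scheme is just my attempt to mirror what DistributedSampler does):

```python
from collections import defaultdict

import torch
from torch.utils.data import Sampler


class DistributedPKSampler(Sampler):
    """Yield PK batches (P classes x K samples each), sharded across DDP ranks."""

    def __init__(self, labels, p, k, num_replicas, rank, seed=0):
        self.labels = list(labels)
        self.p, self.k = p, k
        self.num_replicas = num_replicas
        self.rank = rank
        self.seed = seed
        self.epoch = 0

        # Group dataset indices by class label.
        self.index_by_class = defaultdict(list)
        for idx, label in enumerate(self.labels):
            self.index_by_class[label].append(idx)
        self.classes = sorted(self.index_by_class)

        # Round down so every rank sees the same number of batches per
        # epoch (DDP requires the ranks to stay in step).
        self.num_batches = (len(self.classes) // self.p) // self.num_replicas

    def set_epoch(self, epoch):
        # Same role as DistributedSampler.set_epoch: every rank must use
        # the same epoch-dependent seed so the shuffles agree.
        self.epoch = epoch

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)

        # Identical class permutation on every rank...
        class_perm = torch.randperm(len(self.classes), generator=g).tolist()

        batches = []
        for b in range(len(self.classes) // self.p):
            batch = []
            for c in class_perm[b * self.p:(b + 1) * self.p]:
                pool = self.index_by_class[self.classes[c]]
                if len(pool) >= self.k:
                    picks = torch.randperm(len(pool), generator=g).tolist()[:self.k]
                else:
                    # Sample with replacement if the class has fewer than K items.
                    picks = torch.randint(len(pool), (self.k,), generator=g).tolist()
                batch.extend(pool[i] for i in picks)
            batches.append(batch)

        # ...then rank r keeps every num_replicas-th batch.
        yield from batches[self.rank::self.num_replicas][:self.num_batches]

    def __len__(self):
        return self.num_batches
```

The idea would be to pass this as `batch_sampler=` to the DataLoader on every rank and call `sampler.set_epoch(epoch)` at the start of each epoch, just as one would with DistributedSampler. Would something like this be correct under DDP?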

This post, which explains an approach to creating a DistributedWeightedSampler, might also be relevant.
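
If I understand that approach correctly, the weighted alternative with the Catalyst wrapper would look something like the sketch below (I am assuming the `catalyst.data.DistributedSamplerWrapper` import path, and the inverse-frequency weighting is my own construction):

```python
import numpy as np
from torch.utils.data import DataLoader, WeightedRandomSampler
from catalyst.data import DistributedSamplerWrapper

labels = np.asarray(labels)  # one integer class label per dataset index

# Per-sample weights: inverse class frequency.
class_counts = np.bincount(labels)
weights = 1.0 / class_counts[labels]

# WeightedRandomSampler balances classes in expectation;
# the wrapper shards its output across the DDP ranks.
base_sampler = WeightedRandomSampler(weights, num_samples=len(weights))
sampler = DistributedSamplerWrapper(base_sampler)

loader = DataLoader(dataset, sampler=sampler, batch_size=p * k)
```

My concern is that this only balances classes on average; it does not guarantee the strict P×K structure per batch, so some batches may contain too few samples per class for triplet mining. Has anyone gotten a proper PK sampler working under DDP?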