Sample each image from the dataset N times in a single batch (with DistributedSampler)

I’m currently working on a representation-learning task (deep embeddings). The dataset I use has only one example image per object, and I also use augmentation.

During training, each batch must contain N different augmented versions of a single image from the dataset (dataset[index] always returns a new random transformation).
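To make the requirement concrete, here is a minimal sketch of the behavior I’m after: a wrapper that takes any sampler (e.g. a DistributedSampler) and yields each index N times in a row, so that the DataLoader fetches each image N times and the dataset’s random transform produces N different augmentations. The class name `RepeatedSampler` is just for illustration, not an existing API:

```python
class RepeatedSampler:
    """Wrap any iterable sampler and yield every index `num_repeats`
    times consecutively. Since dataset[index] applies a fresh random
    augmentation on each call, a batch then contains `num_repeats`
    different augmented views of each sampled image."""

    def __init__(self, sampler, num_repeats):
        self.sampler = sampler          # e.g. a DistributedSampler
        self.num_repeats = num_repeats  # N augmented views per image

    def __iter__(self):
        for idx in self.sampler:
            for _ in range(self.num_repeats):
                yield idx

    def __len__(self):
        return len(self.sampler) * self.num_repeats
```

With batch_size = B, each batch would then hold B // N distinct images, N views each. I’m not sure whether something like this already exists in a maintained library (I’ve seen “repeated augmentation” samplers mentioned in some training recipes), hence the question.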

Is there a standard solution or library for this purpose that works with torch.utils.data.distributed.DistributedSampler?