Building each training batch from two datasets

Hello everyone, I’m pretty new to PyTorch.
I have two different datasets, and I need to build my training set so that every batch contains N images: N/2 from dataset1 and N/2 from dataset2.
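One common pattern for this requirement is to wrap both datasets in a single dataset whose index i returns one sample from each. A DataLoader with batch_size N/2 then yields N/2 samples from each dataset per step, which can be concatenated into one batch of N. The sketch below is an illustration under assumptions: PairedDataset is a hypothetical name, both datasets are assumed to return (image, label) pairs, and the toy TensorDatasets stand in for train_set1 and train_set2.

```python
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset


class PairedDataset(Dataset):
    """Hypothetical wrapper: index i returns (sample i of ds1, sample i of ds2)."""

    def __init__(self, ds1, ds2):
        self.ds1, self.ds2 = ds1, ds2

    def __len__(self):
        # Only indices that are valid for both datasets can be used.
        return min(len(self.ds1), len(self.ds2))

    def __getitem__(self, i):
        return self.ds1[i], self.ds2[i]


# Toy stand-ins for train_set1 / train_set2; label 0 or 1 marks the source dataset.
train_set1 = TensorDataset(torch.randn(10, 3, 8, 8), torch.zeros(10, dtype=torch.long))
train_set2 = TensorDataset(torch.randn(12, 3, 8, 8), torch.ones(12, dtype=torch.long))

N = 4  # desired total batch size (assumed even)
loader = DataLoader(PairedDataset(train_set1, train_set2), batch_size=N // 2, shuffle=True)

for (x1, y1), (x2, y2) in loader:
    images = torch.cat([x1, x2])  # N images: first half from set 1, second half from set 2
    labels = torch.cat([y1, y2])
```

Two consequences of this design: shuffling moves whole pairs, so sample i of one dataset is always paired with sample i of the other, and the extra samples of the longer dataset (here indices 10 and 11 of train_set2) are never visited.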
Based on something I read here, I tried a similar approach:

```python
import torch.utils.data as data
from torch.utils.data.distributed import DistributedSampler

concat = data.ConcatDataset(datasets=(train_set1, train_set2))

mixed_training = data.DataLoader(concat, batch_size=opts.batch_size,
                                 sampler=DistributedSampler(ConcatDataset(train_set1, train_set2),
                                                            num_replicas=world_size, rank=rank))
```

Here data.ConcatDataset is PyTorch’s built-in class, and ConcatDataset (without the data. prefix) is the custom class implemented with the code from the previous link.
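One detail worth checking in this snippet: the DistributedSampler is built on the custom ConcatDataset, while the DataLoader reads from PyTorch’s ConcatDataset. The sampler only produces indices in range(len(dataset it was given)), and PyTorch’s ConcatDataset maps every index i < len(train_set1) to train_set1. Assuming the custom class follows the common zip-style pattern (length equal to the shorter dataset), every sampled index lands in train_set1, so dataset2 is never used. The sketch below reproduces this with toy data; PairedDataset is a hypothetical stand-in for the custom class, and world_size/rank are passed explicitly so no process group is needed.

```python
import torch
from torch.utils.data import ConcatDataset, Dataset, DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


class PairedDataset(Dataset):
    """Hypothetical stand-in for the custom zip-style ConcatDataset."""

    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        return tuple(d[i] for d in self.datasets)

    def __len__(self):
        return min(len(d) for d in self.datasets)


# Toy data: the label records which dataset a sample came from.
train_set1 = TensorDataset(torch.zeros(10, 1), torch.zeros(10, dtype=torch.long))
train_set2 = TensorDataset(torch.ones(12, 1), torch.ones(12, dtype=torch.long))
world_size, rank = 2, 0  # explicit values, so torch.distributed is not initialized

# Mismatched: the sampler indexes PairedDataset (len 10),
# but the loader reads PyTorch's ConcatDataset (len 22).
mixed = DataLoader(
    ConcatDataset([train_set1, train_set2]), batch_size=4,
    sampler=DistributedSampler(PairedDataset(train_set1, train_set2),
                               num_replicas=world_size, rank=rank))
mismatched_labels = torch.cat([y for _, y in mixed]).tolist()
print(set(mismatched_labels))  # only label 0: every index falls inside train_set1

# Consistent: build the sampler on the same dataset the loader reads from.
paired = PairedDataset(train_set1, train_set2)
loader = DataLoader(paired, batch_size=2,
                    sampler=DistributedSampler(paired, num_replicas=world_size, rank=rank))
pair_labels = [torch.cat([y1, y2]).tolist() for (x1, y1), (x2, y2) in loader]
```

In the consistent version each batch holds the same number of samples from both datasets, and DistributedSampler still splits the pair indices across ranks.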

I need a DistributedSampler so that I can call its set_epoch method during training.
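For reference, the usual pattern is to call set_epoch(epoch) on the sampler at the start of every epoch, before iterating the DataLoader. DistributedSampler seeds its shuffle with seed + epoch, so without this call every epoch sees the same order. A minimal sketch, assuming a toy dataset and placeholder world_size, rank and num_epochs values (in real training these come from the launcher and torch.distributed):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8))  # toy placeholder dataset
world_size, rank = 2, 0  # placeholders; normally provided by the launcher

# Keep a handle to the sampler: set_epoch belongs to the sampler, not the DataLoader.
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

orders = []
num_epochs = 3
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # reseeds the shuffle for this epoch
    seen = torch.cat([batch[0] for batch in loader]).tolist()
    orders.append(seen)
    # ... forward / backward / optimizer step would go here ...
```

If only the DataLoader is in scope, the same call can be made through its sampler attribute, e.g. mixed_training.sampler.set_epoch(epoch).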
Performance seems pretty good with this setup; I just want to know whether it’s correct, or whether there’s a simpler or more correct way to do it.
Thank you for the help!