Batch balancing with datasets of different lengths

Hi everyone!

I am setting up some experiments with PyTorch and PyTorch Lightning. My problem is how to approach a Dataset class that takes several training splits as input, for example: [dataset_1_train_split.dat, dataset_2_train_split.dat, dataset_3_train_split.dat]. However, d1, d2 and d3 have lengths 10, 100 and 1000 respectively. If I have batch_size = 4, how can I force every batch to be as balanced as possible? By a “balanced batch” I mean a batch that contains samples from every dataset. The epoch should end once all the samples from the biggest dataset have been part of a batch. In this example a batch should look like: [sample_from_d1, sample_from_d2, sample_from_d3, sample_from_<any_of_datasets>]. Of course, this should adapt dynamically to the batch_size and the number of input training splits.
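
To make the question concrete, one possible way to get this behaviour is a custom batch sampler over a ConcatDataset. The sketch below is only illustrative: the BalancedBatchSampler class and the dummy TensorDataset splits are made up here to mirror the sizes in the example, they are not part of PyTorch or of any existing code.

```python
import random

import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset


class BalancedBatchSampler:
    """Yields lists of indices into a ConcatDataset so that every batch
    contains at least one sample from each constituent dataset. The epoch
    spans len(largest dataset) batches, so every sample of the largest
    dataset appears at least once per epoch."""

    def __init__(self, dataset_sizes, batch_size):
        assert batch_size >= len(dataset_sizes), "need one slot per dataset"
        self.dataset_sizes = dataset_sizes
        self.batch_size = batch_size
        # Offset of each dataset inside the ConcatDataset's flat index space.
        self.offsets = [0]
        for size in dataset_sizes[:-1]:
            self.offsets.append(self.offsets[-1] + size)

    def _pool(self, i):
        # Fresh shuffled permutation of dataset i's local indices.
        size = self.dataset_sizes[i]
        return iter(random.sample(range(size), size))

    def __iter__(self):
        pools = [self._pool(i) for i in range(len(self.dataset_sizes))]

        def draw(i):
            try:
                return self.offsets[i] + next(pools[i])
            except StopIteration:  # exhausted (smaller) datasets are reshuffled
                pools[i] = self._pool(i)
                return self.offsets[i] + next(pools[i])

        for _ in range(len(self)):
            # One guaranteed sample from every dataset ...
            batch = [draw(i) for i in range(len(self.dataset_sizes))]
            # ... then fill the remaining slots from randomly chosen datasets.
            while len(batch) < self.batch_size:
                batch.append(draw(random.randrange(len(self.dataset_sizes))))
            yield batch

    def __len__(self):
        # One batch per sample of the largest dataset.
        return max(self.dataset_sizes)


# Dummy splits mirroring the sizes in the question (10, 100 and 1000 samples).
d1 = TensorDataset(torch.zeros(10, 5))
d2 = TensorDataset(torch.ones(100, 5))
d3 = TensorDataset(2 * torch.ones(1000, 5))

full = ConcatDataset([d1, d2, d3])
sampler = BalancedBatchSampler([len(d1), len(d2), len(d3)], batch_size=4)
loader = DataLoader(full, batch_sampler=sampler)
```

In this sketch the epoch length is fixed to len(largest_dataset) batches, which guarantees every sample of the largest split is seen at least once per epoch, at the cost of revisiting some of its samples through the filler slots.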

Thank you in advance; I hope it is possible. If not, how would you approach this problem?
Juan Carlos

In case someone needs to do something similar: I ended up creating an interleaved list of samples from all the datasets and using it as the single dataset for the experiment.
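
Here is a minimal sketch of that interleaving idea, assuming each split has been loaded as an in-memory sequence of samples; the interleave helper and the dummy lists are just illustrative stand-ins, not the exact code used.

```python
from itertools import chain, cycle, islice

from torch.utils.data import DataLoader


def interleave(*splits):
    """Round-robin merge: d1[0], d2[0], d3[0], d1[1], d2[1], d3[1], ...
    Shorter splits are cycled until the longest split has been traversed."""
    longest = max(len(split) for split in splits)
    columns = [list(islice(cycle(split), longest)) for split in splits]
    return list(chain.from_iterable(zip(*columns)))


# Dummy in-memory splits standing in for the three loaded train splits.
d1 = [("d1", i) for i in range(10)]
d2 = [("d2", i) for i in range(100)]
d3 = [("d3", i) for i in range(1000)]

# The interleaved list acts as a map-style dataset (it has __getitem__ and
# __len__). Keep shuffle=False so consecutive samples, and therefore every
# batch, keep alternating between the splits.
merged = interleave(d1, d2, d3)
loader = DataLoader(merged, batch_size=4, shuffle=False)
```

With the smaller splits cycled like this, every batch stays mixed and the epoch ends once the largest split has been traversed once, at the cost of repeating samples from the smaller splits.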