Hi everyone!
I am trying to build some experiments with PyTorch and PyTorch Lightning. My problem is how to design a Dataset class that takes several training splits as input, for example: [dataset_1_train_split.dat, dataset_2_train_split.dat, dataset_3_train_split.dat]. However, d1, d2 and d3 have 10, 100 and 1000 samples respectively. If I use batch_size = 4, how can I force every batch to be as balanced as possible?

By a “balanced batch” I mean a batch that contains samples from every dataset. The epoch should end once every sample from the biggest dataset has been part of a batch. In this example, each batch would look like: [sample_from_d1, sample_from_d2, sample_from_d3, sample_from_<any_of_datasets>]. Of course, this should adapt dynamically to the batch_size and the number of input training splits.
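One possible approach, sketched under assumptions: wrap the splits in a torch.utils.data.ConcatDataset and pass a custom batch sampler to DataLoader through its batch_sampler argument. The BalancedBatchSampler class below is hypothetical (it is not part of PyTorch or Lightning). It gives each of the n datasets batch_size // n slots per batch, rotates the leftover slots round-robin, reshuffles and reuses the smaller datasets, and ends the epoch once every sample of the largest dataset has been drawn. The last batch may repeat a few samples of the largest dataset so that it stays full. The TensorDataset splits are placeholders for the .dat files, which would need their own loading code.

```python
import random
from typing import Iterator, List, Sequence

import torch
from torch.utils.data import ConcatDataset, DataLoader, Sampler, TensorDataset


class BalancedBatchSampler(Sampler[List[int]]):
    """Hypothetical batch sampler over a ConcatDataset.

    Every batch gets batch_size // n samples from each of the n sub-datasets;
    the batch_size % n leftover slots rotate round-robin across datasets.
    Smaller datasets are reshuffled and reused; the epoch ends once every
    sample of the largest dataset has been drawn at least once.
    """

    def __init__(self, dataset_sizes: Sequence[int], batch_size: int, seed: int = 0) -> None:
        if batch_size < len(dataset_sizes):
            raise ValueError("batch_size must be >= the number of datasets")
        self.sizes = list(dataset_sizes)
        self.batch_size = batch_size
        self.rng = random.Random(seed)
        # Start of each sub-dataset inside ConcatDataset's global index space.
        self.offsets = [sum(self.sizes[:i]) for i in range(len(self.sizes))]
        self.largest = max(range(len(self.sizes)), key=lambda i: self.sizes[i])

    def _counts(self, b: int) -> List[int]:
        # Number of samples each dataset contributes to batch number b.
        n = len(self.sizes)
        base, extra = divmod(self.batch_size, n)
        counts = [base] * n
        for e in range(extra):
            counts[(b * extra + e) % n] += 1
        return counts

    def __len__(self) -> int:
        # Batches needed until every sample of the largest dataset is drawn.
        used, b = 0, 0
        while used < self.sizes[self.largest]:
            used += self._counts(b)[self.largest]
            b += 1
        return b

    def _stream(self, i: int) -> Iterator[int]:
        # Endless stream of global indices for dataset i, reshuffled each pass.
        while True:
            perm = list(range(self.sizes[i]))
            self.rng.shuffle(perm)
            for j in perm:
                yield self.offsets[i] + j

    def __iter__(self) -> Iterator[List[int]]:
        streams = [self._stream(i) for i in range(len(self.sizes))]
        for b in range(len(self)):
            batch: List[int] = []
            for i, count in enumerate(self._counts(b)):
                for _ in range(count):
                    batch.append(next(streams[i]))
            yield batch


# Placeholders for the three .dat splits: each sample is just its dataset id.
parts = [TensorDataset(torch.full((n,), k)) for k, n in enumerate([10, 100, 1000])]
sampler = BalancedBatchSampler([len(p) for p in parts], batch_size=4)
loader = DataLoader(ConcatDataset(parts), batch_sampler=sampler)

(ids,) = next(iter(loader))  # dataset ids of the samples in the first batch
```

Because the leftover slots rotate deterministically, __len__ can be computed exactly, so len(loader) works (750 batches for sizes 10/100/1000 and batch_size = 4). Returning this DataLoader from a LightningModule's train_dataloader() should work for single-device training; the sketch does not shard batches across multiple GPUs.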
Thank you in advance; I hope this is possible. If not, how would you approach this problem?
Juan Carlos