Sample from a ConcatDataset so that each batch comes from a single dataset

I’m working on multi-task learning. Task A and Task B share an encoder, but their decoders are different.
So I need to know in each iteration whether the batch comes from dataset A or dataset B.

It seems if I simply do

dataset = ConcatDataset([datasetA, datasetB])

The batches I get while iterating will contain samples from both A and B. But I want each batch to come from either A or B, never a mix, during training.

Any suggestions?

You could use two DataLoaders, create an iterator for each via loader_iter = iter(loader), and in each iteration grab the next batch from the loader you want via next(loader_iter).
This approach gives you the flexibility to apply arbitrarily complicated conditions for deciding which dataset to use.
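
A minimal sketch of the two-loader approach, under some assumptions: the toy TensorDataset objects stand in for your datasetA and datasetB, the alternating rule is just one possible switching condition, and next_batch is a hypothetical helper that restarts a loader once it is exhausted.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-ins for datasetA / datasetB (assumption: replace with your own).
datasetA = TensorDataset(torch.zeros(8, 3))
datasetB = TensorDataset(torch.ones(6, 3))

loaders = {
    "A": DataLoader(datasetA, batch_size=4, shuffle=True),
    "B": DataLoader(datasetB, batch_size=4, shuffle=True),
}
iters = {name: iter(loader) for name, loader in loaders.items()}


def next_batch(name):
    """Hypothetical helper: next batch from one loader, restarting it when exhausted."""
    try:
        return next(iters[name])
    except StopIteration:
        iters[name] = iter(loaders[name])
        return next(iters[name])


seen = []
for step in range(6):
    # Assumed switching rule: alternate between the tasks. Any condition works here.
    task = "A" if step % 2 == 0 else "B"
    (x,) = next_batch(task)
    seen.append((task, x))
    # Here you would run the shared encoder on x, then the decoder for `task`.
```

Because you pick the loader yourself, you always know which decoder to use for the current batch.
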
On the other hand, if you want to switch between both datasets in each iteration, you could create a custom sampler and create the indices for the ConcatDataset as you wish.
In the simplest case you could return the indices as: [dataA_idx0, dataA_idx1, dataA_idx2, ... dataB_idx0, ...].
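
Note that with plain sequential indices and a fixed batch size, the batch at the A/B boundary can still mix both datasets unless len(datasetA) is a multiple of the batch size. One way around this is to build the batches yourself. The following is only a sketch: PerDatasetBatchSampler is a hypothetical class, not part of PyTorch, and it relies on ConcatDataset numbering the samples of the first dataset from 0 and offsetting the second dataset by the length of the first.

```python
import random


class PerDatasetBatchSampler:
    """Hypothetical sampler: yields lists of ConcatDataset indices,
    where every list stays inside a single underlying dataset."""

    def __init__(self, sizes, batch_size, shuffle=True, seed=None):
        self.sizes = list(sizes)          # e.g. [len(datasetA), len(datasetB)]
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = random.Random(seed)

    def __iter__(self):
        batches, offset = [], 0
        for size in self.sizes:
            # Global indices of this dataset inside the ConcatDataset.
            idx = list(range(offset, offset + size))
            if self.shuffle:
                self.rng.shuffle(idx)
            # Chunk per dataset, so no batch crosses a dataset boundary.
            for i in range(0, size, self.batch_size):
                batches.append(idx[i:i + self.batch_size])
            offset += size
        if self.shuffle:
            self.rng.shuffle(batches)     # interleave A and B batches
        return iter(batches)

    def __len__(self):
        # Number of batches per epoch (ceil division per dataset).
        return sum(-(-s // self.batch_size) for s in self.sizes)


sampler = PerDatasetBatchSampler(sizes=[10, 7], batch_size=4, seed=0)
batches = list(sampler)
```

An object like this could be passed as DataLoader(ConcatDataset([datasetA, datasetB]), batch_sampler=sampler). The loader does not tell you which dataset a batch came from, so each sample would probably need to carry a task label as well.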