I’m working on multi-task learning. Task A and Task B share an encoder but have separate decoders.
So in each iteration I need to know whether the current batch comes from dataset A or dataset B.
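To make the setup concrete, here is a minimal sketch of a model with one shared encoder and one decoder per task. The forward pass takes a task key and routes the encoded features to that task's decoder. The class name `MultiTaskModel`, the layer sizes, and the `"A"`/`"B"` keys are all hypothetical placeholders, not the real architecture.

```python
import torch
import torch.nn as nn

class MultiTaskModel(nn.Module):
    """Hypothetical example: shared encoder, one decoder per task (sizes are placeholders)."""

    def __init__(self, in_dim=16, hidden=32, out_a=3, out_b=5):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        # ModuleDict registers both decoders so their parameters are trained and saved.
        self.decoders = nn.ModuleDict({
            "A": nn.Linear(hidden, out_a),
            "B": nn.Linear(hidden, out_b),
        })

    def forward(self, x, task):
        # Every batch goes through the same encoder...
        features = self.encoder(x)
        # ...but only the decoder for this batch's task is applied.
        return self.decoders[task](features)

model = MultiTaskModel()
x = torch.randn(4, 16)
out_a = model(x, "A")
out_b = model(x, "B")
```

Because `forward` takes a single task key, it assumes every sample in a batch belongs to the same task, which is why the batches need to be homogeneous.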
It seems that if I simply do
dataset = ConcatDataset([datasetA, datasetB])
then a batch produced while iterating over the DataLoader can contain samples from both A and B. But during training I want each batch to come entirely from A or entirely from B.
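One possible workaround is to keep the `ConcatDataset` but pass a custom batch sampler to the `DataLoader`. The sampler uses the `ConcatDataset.cumulative_sizes` attribute to find which global indices belong to each sub-dataset, shuffles and chunks each range separately, and then shuffles the order of the resulting batches. This is a minimal sketch, assuming map-style datasets: `PerDatasetBatchSampler` is a hypothetical name, and the toy `TensorDataset`s below carry a hypothetical task-id column (0 for A, 1 for B) so the training loop can tell which decoder to use.

```python
import random
import torch
from torch.utils.data import ConcatDataset, DataLoader, Sampler, TensorDataset

class PerDatasetBatchSampler(Sampler):
    """Hypothetical sampler: every batch of indices comes from a single sub-dataset."""

    def __init__(self, concat_dataset, batch_size, shuffle=True, drop_last=False):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        # Global index range [start, end) of each sub-dataset inside the ConcatDataset.
        ends = concat_dataset.cumulative_sizes
        starts = [0] + ends[:-1]
        self.ranges = list(zip(starts, ends))

    def _batches(self):
        batches = []
        for start, end in self.ranges:
            indices = list(range(start, end))
            if self.shuffle:
                random.shuffle(indices)  # shuffle within one sub-dataset only
            for i in range(0, len(indices), self.batch_size):
                batch = indices[i:i + self.batch_size]
                if len(batch) == self.batch_size or not self.drop_last:
                    batches.append(batch)
        if self.shuffle:
            random.shuffle(batches)  # interleave A-batches and B-batches
        return batches

    def __iter__(self):
        return iter(self._batches())

    def __len__(self):
        total = 0
        for start, end in self.ranges:
            n = end - start
            # Ceiling division unless incomplete batches are dropped.
            total += n // self.batch_size if self.drop_last else -(-n // self.batch_size)
        return total

# Toy stand-ins for datasetA / datasetB; the second tensor is a hypothetical task id.
datasetA = TensorDataset(torch.randn(10, 4), torch.zeros(10, dtype=torch.long))
datasetB = TensorDataset(torch.randn(7, 4), torch.ones(7, dtype=torch.long))
dataset = ConcatDataset([datasetA, datasetB])

# batch_sampler replaces batch_size/shuffle/sampler on the DataLoader.
loader = DataLoader(dataset, batch_sampler=PerDatasetBatchSampler(dataset, batch_size=4))

for x, task_id in loader:
    # All task ids in a batch are identical, so one decoder handles the whole batch.
    task = "A" if task_id[0].item() == 0 else "B"
```

Shuffling the batch order mixes A and B batches randomly across an epoch. If a strict alternation is preferred, the final `random.shuffle(batches)` could be replaced with a round-robin interleave of the per-dataset batch lists.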