Pytorch dataset sampler for multi task learning

I have one main feature extractor x with two tasks, A and B, and two different image datasets, d1 and d2, which are highly unbalanced in size (based on number of batches): say 5,000 images in d1 and 100,000 in d2.
I need to pass d1 through x+A and d2 through x+B separately and backpropagate.
How do I design the data loader to train the network?


I think that the easiest solution would be to use one DataLoader per Dataset.
In fact, if you use only one DataLoader for both tasks, the problem is that one batch can contain images from both d1 and d2. That is fine for your main feature extractor, but it becomes painful when you have to separate the produced features for tasks A and B, since splitting the batch yields batches of varying lengths, possibly even empty ones, etc.
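A minimal sketch of the two-DataLoader approach: the datasets, the extractor x, and the heads A and B below are toy placeholders (the original post doesn't specify the architectures), but the training-loop shape is the point. Iterating over `zip(loader1, loader2)` gives you one clean batch per task per step, so each loss goes to the right head with no batch splitting:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-ins for d1 and d2; in the original question these would be the
# 5,000- and 100,000-image datasets.
d1 = TensorDataset(torch.randn(50, 3, 8, 8), torch.randint(0, 2, (50,)))
d2 = TensorDataset(torch.randn(100, 3, 8, 8), torch.randint(0, 5, (100,)))

loader1 = DataLoader(d1, batch_size=10, shuffle=True)
loader2 = DataLoader(d2, batch_size=10, shuffle=True)

# Shared feature extractor x and two task-specific heads A and B
# (placeholder architectures).
x = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16), nn.ReLU())
A = nn.Linear(16, 2)
B = nn.Linear(16, 5)

params = list(x.parameters()) + list(A.parameters()) + list(B.parameters())
opt = torch.optim.SGD(params, lr=0.01)
criterion = nn.CrossEntropyLoss()

# One batch from each loader per step; zip stops when the shorter loader
# (d1's) is exhausted, so one "epoch" here sees d1 once and part of d2.
for (img1, y1), (img2, y2) in zip(loader1, loader2):
    opt.zero_grad()
    loss = criterion(A(x(img1)), y1) + criterion(B(x(img2)), y2)
    loss.backward()  # gradients flow into A, B, and the shared x
    opt.step()
```

Because d2 is 20x larger, one pass of this loop only covers a fraction of d2; you could instead re-create an iterator over loader1 whenever it runs out, or subsample d2 per epoch, depending on how you want to balance the tasks.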

If you really want to use a single DataLoader, you can have a look at ConcatDataset.
