DDP for multiple dataloaders with their own loss functions


I am training a multi-task learning model. Because of our data collection strategy, I have several datasets, each of which corresponds to a different output branch of the model.

I want to use DDP to accelerate training. The standard approach is probably to use a DistributedSampler on each dataloader. I wonder whether it is possible to do the following instead: on each GPU, train on one specific dataset with its own loss function, then aggregate the losses to update the model.

If possible, are there any tutorials I can follow?

Thanks for your patience and help!

Yes, this should be possible - you’ll likely have to implement your own custom Dataset class as documented here: torch.utils.data.dataset — PyTorch 1.11.0 documentation

For example, this Dataset class can contain your N datasets (per worker), and you can call dist.get_rank() (see "Distributed communication package - torch.distributed" in the PyTorch 1.11.0 documentation) to pick the right dataset for the given worker.
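A minimal sketch of that idea, not a tested recipe: the names RankSelectedDataset and current_rank are hypothetical, and mapping ranks to datasets with a modulo is an assumption for the case where there are more GPUs than datasets. The guarded import is only there so the selection logic can be tried on a machine without PyTorch.

```python
from typing import Optional, Sequence

try:
    import torch.distributed as dist
    from torch.utils.data import Dataset
except ImportError:  # fallback so the selection logic can be tried without PyTorch
    dist = None
    Dataset = object


def current_rank() -> int:
    """Return this process's rank, or 0 when no process group is initialized."""
    if dist is not None and dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0


class RankSelectedDataset(Dataset):
    """Hypothetical wrapper: holds N datasets and exposes only the one for this rank."""

    def __init__(self, datasets: Sequence, rank: Optional[int] = None):
        if len(datasets) == 0:
            raise ValueError("at least one dataset is required")
        self.rank = current_rank() if rank is None else rank
        # Assumption: if there are more ranks than datasets, ranks wrap around.
        self.task_id = self.rank % len(datasets)
        self.dataset = datasets[self.task_id]

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index):
        # Return the task id with the sample so the training loop can
        # choose the matching output branch and loss function.
        return self.task_id, self.dataset[index]
```

Each process would build this dataset after init_process_group and wrap it in an ordinary DataLoader. Two things to watch for with this layout: DDP averages gradients across processes, so the update is effectively based on the mean of the per-task gradients, and if a rank only runs its own branch, the other branches get no gradient on that rank, which usually means constructing DistributedDataParallel with find_unused_parameters=True. Datasets of different lengths also give ranks different numbers of batches per epoch, which can stall the gradient synchronization unless the number of steps per epoch is made equal.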

I see, I will try it. Thanks for your suggestions!