Multitask learning with multiple uneven datasets


I want to train a backbone model coupled with several output heads for different tasks.
Each task has a different dataset with varying lengths.

My current approach is to train the model for 10 epochs on each dataset, swapping out the head for each specific task.
However, this most likely means that the backbone weights learned on the first task are ‘forgotten’ (catastrophic forgetting) after enough downstream tasks (my guess, perhaps I am mistaken?)
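For concreteness, my current sequential scheme looks roughly like the sketch below (all names and the toy data are hypothetical, just to illustrate the loop structure):

```python
import torch
import torch.nn as nn

# Toy setup: one shared backbone, one head + loss per task.
backbone = nn.Linear(8, 16)
heads = {"binary": nn.Linear(16, 1), "regress": nn.Linear(16, 1)}
losses = {"binary": nn.BCEWithLogitsLoss(), "regress": nn.MSELoss()}

# Fake datasets of different lengths, one per task.
datasets = {
    "binary": [(torch.randn(8), torch.rand(1).round()) for _ in range(20)],
    "regress": [(torch.randn(8), torch.randn(1)) for _ in range(35)],
}

# Sequential scheme: finish all 10 epochs on one task before moving on.
# By the time the last task is done, nothing has pulled the backbone
# back toward the earlier tasks -- hence the forgetting worry.
for task, data in datasets.items():
    params = list(backbone.parameters()) + list(heads[task].parameters())
    opt = torch.optim.SGD(params, lr=0.01)
    for epoch in range(10):
        for x, y in data:
            opt.zero_grad()
            loss = losses[task](heads[task](backbone(x)), y)
            loss.backward()
            opt.step()
```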

My thoughts on better ways to approach the problem are as follows. If someone could point me in the right direction I would greatly appreciate it:

  1. Crop all datasets to the length of the shortest dataset, thereby removing the problem of uneven datasets; optimize on the shortened datasets, then fine-tune the head on the full dataset for each specific task at the end. (Pros: easy to implement and quick. Cons: a smaller data distribution to work with.)

  2. Use multiprocessing to train the model on the different datasets at the same time (not sure if this would work, i.e. update the backbone weights on each step with a different task, so each optimizer step touches a different task and the model's ‘memory’ contains a mixture of each task).
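For what it's worth, option 2 doesn't strictly need multiprocessing: interleaving tasks step by step in a single process already gives every optimizer update a different task. A hedged sketch (names and toy data are mine, not from any particular library recipe):

```python
import itertools
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

backbone = nn.Linear(8, 16)
heads = {"binary": nn.Linear(16, 1), "regress": nn.Linear(16, 1)}
losses = {"binary": nn.BCEWithLogitsLoss(), "regress": nn.MSELoss()}

# Uneven toy datasets wrapped in loaders.
loaders = {
    "binary": DataLoader(
        [(torch.randn(8), torch.rand(1).round()) for _ in range(20)], batch_size=4),
    "regress": DataLoader(
        [(torch.randn(8), torch.randn(1)) for _ in range(35)], batch_size=4),
}

params = list(backbone.parameters()) + [p for h in heads.values() for p in h.parameters()]
opt = torch.optim.SGD(params, lr=0.01)

# Round-robin over tasks: each step draws the next batch from one task,
# so consecutive backbone updates come from different tasks.
# itertools.cycle restarts the shorter loader so the longer one
# is not truncated (the opposite trade-off to cropping in option 1).
iters = {t: itertools.cycle(dl) for t, dl in loaders.items()}
steps = max(len(dl) for dl in loaders.values())
for _ in range(steps):
    for task in loaders:
        x, y = next(iters[task])
        opt.zero_grad()
        loss = losses[task](heads[task](backbone(x)), y)
        loss.backward()
        opt.step()
```

The `cycle` over the shorter loader means its samples are seen more often per pass, which may or may not be what you want; sampling tasks proportionally to dataset size is the other common choice.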

Thank you, let me know if I didn’t explain clearly enough and I will try to do better

The approach I ended up taking was to concatenate the datasets together and organize datapoints that had the same task together for easy separation later.
I.e. in my collate_fn I sorted the datapoints so all binary classification datapoints were together and all regression datapoints were together. Then at training time I iterated over the sharded batch, feeding each shard to the backbone model and the respective head model, and finally backpropagated with the total loss.