Parallelizing loop over task-specific batches through task-specific weights


I was hoping for guidance on how to parallelize forward passes of task-specific batches through task-specific weights. My multitask model looks something like this:

import torch.nn as nn

class my_model(nn.Module):
     def __init__(self, task_names):
          super().__init__()  # must run before submodules are registered
          self.task_names = task_names  # list of strings
          self.shared_layer = nn.Linear(64, 64)
          # one single-output head per task
          self.heads = nn.ModuleDict({task_name: nn.Linear(64, 1) for
                                      task_name in self.task_names})
     def forward(self, in_data, task_name):
          feature = self.shared_layer(in_data)
          return self.heads[task_name](feature)

Now, when I go to train this thing, I loop over task-specific batches, compute the task losses, and then perform a backward pass over the total loss. The code looks like:

task_losses = []
for task_name in task_names:
     in_data, target = next(task_dataloaders[task_name])  # task_dataloaders is a dict
     prediction = my_model(in_data, task_name)
     task_loss = loss(prediction, target)
     task_losses.append(task_loss)  # collect each task's loss for the sum below
total_loss = torch.sum(torch.stack(task_losses))

My training seems to be pretty slow for ~20 tasks and I suspect it can be drastically sped up by parallelizing the for loop, possibly on multiple GPUs. Is there a recommended way to do this? I’m not sure how to use the traditional nn.DataParallel method with task-specific weights. Thanks!

Hi @rcmse,

Instead of defining a separate linear layer for each task and looping over them, you can define a single linear layer nn.Linear(64, len(task_names)), which gives you an output of size [batch_sz, len(task_names)] with one column per task.


x = torch.rand(1,5)

# A linear layer for each individual task...
l1 = torch.nn.Linear(5,1)
l2 = torch.nn.Linear(5,1)

o1 = l1(x)    # tensor([[0.4597]])
o2 = l2(x)    # tensor([[0.0697]])

# ... is equivalent to one (faster) linear layer

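l = torch.nn.Linear(5,2)
# copy each single-output layer into one row of the fused layer
l.weight.data[0] = l1.weight.data[0]
l.bias.data[0] = l1.bias.data[0]
l.weight.data[1] = l2.weight.data[0]
l.bias.data[1] = l2.bias.data[0]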

o = l(x)   # tensor([[0.4597, 0.0697]])

When the inputs are the same for all tasks, as is common in multitask learning, you can skip the loop entirely by making your target a single tensor of shape [batch_sz, len(task_names)] instead of a separate [batch_sz, 1] tensor for each task.
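A minimal sketch of that shared-input case, to make the shapes concrete. The task names, batch size, random data, and the choice of MSE as the loss are all assumptions for illustration; the original post does not say which loss is used.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
task_names = ["task_a", "task_b", "task_c"]  # hypothetical task names
batch_sz = 8                                 # hypothetical batch size

shared_layer = nn.Linear(64, 64)
fused_head = nn.Linear(64, len(task_names))  # one output column per task

in_data = torch.rand(batch_sz, 64)               # same inputs for every task
target = torch.rand(batch_sz, len(task_names))   # one target column per task

prediction = fused_head(shared_layer(in_data))   # [batch_sz, len(task_names)]

# reduction="none" keeps element-wise losses, so averaging over the batch
# dimension yields one loss value per task (MSE is an assumed loss here)
per_task_loss = nn.functional.mse_loss(
    prediction, target, reduction="none"
).mean(dim=0)
total_loss = per_task_loss.sum()
total_loss.backward()
```

Keeping the per-task losses separate before summing makes it possible to log or weight each task individually, as the original loop did.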

Hi,

Thanks for your reply! Unfortunately, my inputs are not the same for all tasks: my task datasets are all of different sizes, sometimes differing by an order of magnitude.

For context, it’s a scientific dataset and I cannot really change it.
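For per-task inputs like these, one possible way to keep the single nn.Linear(64, len(task_names)) layer suggested above is to concatenate the task batches, remember which task each row belongs to, and select the matching output column with gather. This is only a sketch under assumptions: the task names, batch sizes, random stand-in data, and MSE loss are hypothetical, and it computes every head for every row, which is cheap for single-output heads but wasteful for large ones.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
task_names = ["task_a", "task_b", "task_c"]  # hypothetical task names
batch_sizes = [4, 16, 2]                     # hypothetical, differ per task

shared_layer = nn.Linear(64, 64)
fused_head = nn.Linear(64, len(task_names))

# stand-ins for the batches drawn from each task's dataloader
inputs = [torch.rand(n, 64) for n in batch_sizes]
targets = [torch.rand(n, 1) for n in batch_sizes]

in_data = torch.cat(inputs)    # [sum(batch_sizes), 64]
target = torch.cat(targets)    # [sum(batch_sizes), 1]
# task index of every row in the concatenated batch
task_idx = torch.cat([
    torch.full((n,), i, dtype=torch.long) for i, n in enumerate(batch_sizes)
])

all_outputs = fused_head(shared_layer(in_data))             # [N, n_tasks]
prediction = all_outputs.gather(1, task_idx.unsqueeze(1))   # [N, 1]

# per-row loss (MSE assumed), then mean within each task, then sum over tasks
per_row_loss = nn.functional.mse_loss(
    prediction, target, reduction="none"
).squeeze(1)
loss_sums = torch.zeros(len(task_names)).index_add(0, task_idx, per_row_loss)
per_task_loss = loss_sums / torch.tensor(batch_sizes, dtype=torch.float)
total_loss = per_task_loss.sum()
total_loss.backward()
```

With mean-reduced per-task losses, this total should match what the original loop produces, while running the shared layer and the heads in one forward pass.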