# Parallelizing loop over task-specific batches through task-specific weights

Hi,

I was hoping for guidance on how to parallelize forward passes of task-specific batches through task-specific weights. My multitask model looks something like this:

```python
class my_model(nn.Module):
    def __init__(self, task_names):
        super().__init__()
        self.task_names = task_names  # list of strings
        self.shared_layer = nn.Linear(64, 64)
        # one head per task
        self.task_layers = nn.ModuleDict({t: nn.Linear(64, 1) for t in task_names})

    def forward(self, in_data, task_name):
        feature = self.shared_layer(in_data)
        return self.task_layers[task_name](feature)
```

Now, when I go to train this thing, I loop over task-specific batches, compute the task losses, and then perform a backward pass over the total loss. The code looks like:

```python
task_losses = []
for task_name, (in_data, target) in task_batches.items():
    prediction = model(in_data, task_name)  # model is an instance of my_model
    task_loss = loss(prediction, target)
    task_losses.append(task_loss)
total_loss = sum(task_losses)
total_loss.backward()
```

My training is quite slow for ~20 tasks, and I suspect it could be sped up drastically by parallelizing the `for` loop, possibly across multiple GPUs. Is there a recommended way to do this? I’m not sure how to use the traditional `nn.DataParallel` approach with task-specific weights. Thanks!

Hi @rcmse,

Instead of defining a separate linear layer for each task and looping over them, you can define a single linear layer `nn.Linear(64, len(task_names))` that gives you an output of size `[batch_sz, len(task_names)]`.

Example:

```python
x = torch.rand(1, 5)

# A linear layer for each individual task...
l1 = torch.nn.Linear(5, 1)
l2 = torch.nn.Linear(5, 1)

o1 = l1(x)    # tensor([[0.4597]])
o2 = l2(x)    # tensor([[0.0697]])

# ... is equivalent to one (faster) linear layer:
l = torch.nn.Linear(5, 2)
l.weight.data[0] = l1.weight.data
l.bias.data[0] = l1.bias.data
l.weight.data[1] = l2.weight.data
l.bias.data[1] = l2.bias.data

o = l(x)   # tensor([[0.4597, 0.0697]])
```
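Applied to a model like the one in the question, this amounts to replacing the per-task heads with one fused output layer. A minimal sketch (`MergedModel` and `n_tasks` are illustrative names, not from the post above):

```python
import torch
import torch.nn as nn

class MergedModel(nn.Module):
    """All task heads fused into a single output layer."""
    def __init__(self, n_tasks):
        super().__init__()
        self.shared_layer = nn.Linear(64, 64)
        self.heads = nn.Linear(64, n_tasks)  # replaces the per-task loop

    def forward(self, in_data):
        feature = self.shared_layer(in_data)
        return self.heads(feature)  # shape [batch_sz, n_tasks]

model = MergedModel(n_tasks=20)
out = model(torch.rand(4, 64))
print(out.shape)  # torch.Size([4, 20])
```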

Since in multitask learning the inputs are the same for all tasks, you can skip the loop entirely by modifying your targets to be a single tensor of shape `[batch_sz, len(task_names)]` instead of one `[batch_sz, 1]` tensor per task.
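With stacked targets, the whole multitask step collapses into one forward and one backward pass. A sketch assuming an MSE loss (`head`, `features`, and `targets` are placeholder names):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch_sz, n_tasks = 8, 3
head = nn.Linear(64, n_tasks)            # one output column per task
features = torch.rand(batch_sz, 64)      # same input for every task
targets = torch.rand(batch_sz, n_tasks)  # stacked instead of n_tasks separate [batch_sz, 1] tensors

predictions = head(features)             # [batch_sz, n_tasks] in a single pass
total_loss = F.mse_loss(predictions, targets)
total_loss.backward()                    # gradients for all task columns at once
```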

Hi @suraj.pt,

Thanks for your reply! Unfortunately, my inputs are not the same for all tasks: my task datasets are all of different sizes, sometimes by an order of magnitude.

For context, it’s a scientific dataset and I cannot really change it.