Hi,
I was hoping for guidance on how to parallelize forward passes of task-specific batches through task-specific weights. My multitask model looks something like this:
import torch
import torch.nn as nn

class my_model(nn.Module):
    def __init__(self, task_names):
        super().__init__()  # required before registering submodules
        self.task_names = task_names  # list of strings
        self.shared_layer = nn.Linear(64, 64)
        self.heads = nn.ModuleDict(
            {task_name: nn.Linear(64, 1) for task_name in self.task_names}
        )

    def forward(self, in_data, task_name):
        feature = self.shared_layer(in_data)
        return self.heads[task_name](feature)
Now, when I go to train this thing, I loop over task-specific batches, compute the task losses, and then perform a backward pass over the total loss. The code looks like:
model = my_model(task_names)  # instance of the class above
loss_fn = nn.MSELoss()  # stand-in for my actual criterion
# task_dataloaders is a dict of DataLoaders; wrap each in iter() so next() works
task_iters = {t: iter(task_dataloaders[t]) for t in task_names}

task_losses = []
for task_name in task_names:
    in_data, target = next(task_iters[task_name])
    prediction = model(in_data, task_name)
    task_loss = loss_fn(prediction, target)
    task_losses.append(task_loss)
total_loss = torch.sum(torch.stack(task_losses))
total_loss.backward()
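For reference, one direction I've been toying with (a sketch, not my working code) is fusing all the heads into a single batched matmul, since every head here has the same nn.Linear(64, 1) shape. It assumes equal batch sizes across tasks and a hypothetical task_batches dict of per-task input tensors:

# Sketch only: stack the per-task head weights so one bmm computes every
# task's output at once, instead of looping over heads in Python.
# Assumes all heads are nn.Linear(64, 1) and every batch has the same size B.
inputs = torch.stack([task_batches[t] for t in task_names])  # (T, B, 64)
features = model.shared_layer(inputs)  # (T, B, 64); Linear acts on the last dim
head_w = torch.stack([model.heads[t].weight for t in task_names])  # (T, 1, 64)
head_b = torch.stack([model.heads[t].bias for t in task_names])  # (T, 1)
# (T, B, 64) @ (T, 64, 1) -> (T, B, 1); the bias broadcasts over the batch dim
predictions = torch.bmm(features, head_w.transpose(1, 2)) + head_b.unsqueeze(1)

Gradients should still reach the original head parameters through torch.stack, but I don't know whether this is idiomatic or how it would generalize to heads that aren't all the same shape.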
My training seems pretty slow for ~20 tasks, and I suspect it could be sped up substantially by parallelizing the for loop, possibly on multiple GPUs. Is there a recommended way to do this? I'm not sure how to use the traditional nn.DataParallel method with task-specific weights, since DataParallel splits one batch across devices rather than routing different tasks to different weights. The closest thing I've managed to write myself is the stream-based sketch below.
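A minimal sketch of that idea, assuming a single GPU with the model and batches already on it (task_iters and loss_fn as above): each task's forward is launched on its own CUDA stream so the small per-task kernels can overlap instead of running strictly one after another.

# Sketch only: one CUDA stream per task so forward kernels can overlap.
# Assumes a single GPU and that the model and data already live on it.
streams = {t: torch.cuda.Stream() for t in task_names}
task_losses = []
torch.cuda.synchronize()  # finish prior work before forking onto side streams
for task_name in task_names:
    with torch.cuda.stream(streams[task_name]):
        in_data, target = next(task_iters[task_name])
        prediction = model(in_data, task_name)
        task_losses.append(loss_fn(prediction, target))
torch.cuda.synchronize()  # join all streams before the combined backward
total_loss = torch.sum(torch.stack(task_losses))
total_loss.backward()

I have no idea whether this is actually safe with autograd, or whether streams buy anything for layers this small, which is why I'm asking. Thanks!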