I am working on a model for multi-task learning that has a shared part and multiple independent branches for the different tasks.
Here is a minimal working example (MWE) of such a model:
import torch
from torch import nn


class MT(nn.Module):
    def __init__(self):
        super(MT, self).__init__()
        # Shared layer.
        self.shared = nn.Linear(in_features=10, out_features=10)
        # Branch for task 0.
        self.branch0 = nn.Linear(in_features=10, out_features=1)
        # Branch for task 1.
        self.branch1 = nn.Linear(in_features=10, out_features=1)

    def forward(self, x):
        out = self.shared(x)
        out_0 = self.branch0(out)
        out_1 = self.branch1(out)
        return out_0, out_1
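For context, this is roughly how I train it (simplified; the data, loss functions, and optimizer are just placeholders): both task losses are summed so that a single backward call covers the shared part and both branches.

model = MT().cuda()
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.MSELoss()

# Dummy batch: 32 samples, 10 input features, one target per task.
x = torch.randn(32, 10, device="cuda")
y0 = torch.randn(32, 1, device="cuda")
y1 = torch.randn(32, 1, device="cuda")

optimizer.zero_grad()
out_0, out_1 = model(x)
# Sum the per-task losses and backpropagate through both branches at once.
loss = criterion(out_0, y0) + criterion(out_1, y1)
loss.backward()
optimizer.step()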
I suspect that the branches leave the GPU underutilized (starving), because the shared part is much more complex than the branches and the two small branch computations appear to run one after the other. I therefore wonder whether PyTorch recognizes that self.branch0(out) and self.branch1(out) are independent of each other, so that their forward and backward passes can be computed in parallel. If it does not, is there a way to achieve parallel execution of the branches?