I am working on a model for multi-task learning that has a shared part and multiple independent branches for the different tasks.
This is an MWE of such a model:
import torch
from torch import nn


class MT(nn.Module):
    def __init__(self):
        super(MT, self).__init__()
        # Shared layer.
        self.shared = nn.Linear(in_features=10, out_features=10)
        # Branch for task 0.
        self.branch0 = nn.Linear(in_features=10, out_features=1)
        # Branch for task 1.
        self.branch1 = nn.Linear(in_features=10, out_features=1)

    def forward(self, x):
        out = self.shared(x)
        out_0 = self.branch0(out)
        out_1 = self.branch1(out)
        return out_0, out_1
I suspect that the independent branches are starving the GPU, since the shared part is more complex than the branches. Thus, I wonder whether PyTorch understands that the two branch calls self.branch0(out) and self.branch1(out) are independent, so that their forward and backward passes can be computed in parallel. If not, is there a way to achieve parallel execution of the branches?
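For concreteness, this is the kind of thing I have been experimenting with: a minimal sketch of a replacement for MT.forward using CUDA streams (torch.cuda.Stream), assuming streams are even the right tool here. I am not sure it is correct or that it actually overlaps the two branches on the GPU:

    # A minimal sketch, replacing MT.forward from the MWE above. The stream
    # handles and the synchronization points are my guesses at what is needed;
    # I have not verified that the branches actually run concurrently.
    def forward(self, x):
        out = self.shared(x)
        if out.is_cuda:
            s0 = torch.cuda.Stream()
            s1 = torch.cuda.Stream()
            # The side streams must wait until the shared layer has produced `out`.
            s0.wait_stream(torch.cuda.current_stream())
            s1.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s0):
                out_0 = self.branch0(out)
            with torch.cuda.stream(s1):
                out_1 = self.branch1(out)
            # Rejoin the default stream before anything downstream uses the outputs.
            torch.cuda.current_stream().wait_stream(s0)
            torch.cuda.current_stream().wait_stream(s1)
        else:
            out_0 = self.branch0(out)
            out_1 = self.branch1(out)
        return out_0, out_1

I am also unsure whether this helps the backward pass at all. I have seen torch.jit.fork mentioned for this kind of inter-op parallelism, but as far as I understand it only runs asynchronously under TorchScript, not in eager mode.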