Execute multi-task branches in parallel

I am working on a model for multi-task learning that has a shared part and multiple independent branches for the different tasks.

Here is an MWE of such a model:

import torch
from torch import nn

class MT(nn.Module):
    def __init__(self):
        super(MT, self).__init__()
        # Shared layer.
        self.shared = nn.Linear(in_features=10, out_features=10)
        # Branch for task 0.
        self.branch0 = nn.Linear(in_features=10, out_features=1)
        # Branch for task 1.
        self.branch1 = nn.Linear(in_features=10, out_features=1)
    def forward(self, x):
        out = self.shared(x)
        out_0 = self.branch0(out)
        out_1 = self.branch1(out)
        return out_0, out_1

I suspect that the independent branches are starving the GPU, since the shared part of the model is more complex than the branches. So I wonder: does PyTorch understand that the two branches self.branch0(out) and self.branch1(out) are independent, and that their forward and backward passes can be computed in parallel? If not, is there a way to achieve parallel execution of the branches?
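One approach I came across is torch.jit.fork / torch.jit.wait, which schedule a function call asynchronously and return a Future. I am not sure whether this actually gives parallel execution in eager mode (the documentation suggests real parallelism mainly applies under TorchScript), but as a sketch it would look like this:

```python
import torch
from torch import nn

class MT(nn.Module):
    def __init__(self):
        super().__init__()
        # Shared layer.
        self.shared = nn.Linear(in_features=10, out_features=10)
        # Branches for tasks 0 and 1.
        self.branch0 = nn.Linear(in_features=10, out_features=1)
        self.branch1 = nn.Linear(in_features=10, out_features=1)

    def forward(self, x):
        out = self.shared(x)
        # Fork the two independent branches; each call returns a Future.
        fut0 = torch.jit.fork(self.branch0, out)
        fut1 = torch.jit.fork(self.branch1, out)
        # Wait on both futures to get the actual tensors.
        return torch.jit.wait(fut0), torch.jit.wait(fut1)
```

Would this be the right way to express the independence of the branches, or is there a better mechanism (e.g. CUDA streams)?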
