Model parallelism for branching model

I have a branching model in which several submodules may be executed independently in parallel. Below is the simplest example I could come up with:

import torch
import torch.nn as nn

class myModule(nn.Module):
    def __init__(self, size_x, size_y=5):
        super(myModule, self).__init__()

        # One encoder per input channel
        self.f0 = nn.Linear(size_x, size_y)
        self.f1 = nn.Linear(size_x, size_y)

        # Combines the two encoder outputs
        self.g  = nn.Linear(2*size_y, size_y)

        # One decoder per input channel
        self.h0 = nn.Linear(2*size_y, size_x)
        self.h1 = nn.Linear(2*size_y, size_x)

    def forward(self, x):
        y0 = self.f0(x[:,0,:])
        y1 = self.f1(x[:,1,:])
        y = torch.cat((y0,y1),dim=1)

        yg = self.g(y)

        out = torch.zeros_like(x)  # keeps x's dtype and device
        out[:,0,:] = self.h0(torch.cat((y0,yg),dim=1))
        out[:,1,:] = self.h1(torch.cat((y1,yg),dim=1))
        return out

The functions f0 and f1 act independently on separate channels of the input. Function g combines the outputs of f0 and f1 into a single vector yg, which is then used by the h functions (along with the output of the corresponding f function) to attempt to recreate the input. I've used plain linear layers as submodules in this example, but in my application they are more complicated, are large in their own right, and there are more than two of them.

My goal is to execute f0 and f1 in parallel as well as h0 and h1 in parallel. I’m OK with blocking while g executes. I’m particularly interested in how to make sure that the backward pass executes in parallel.

You can use torch.jit.fork for this; note, however, that your forward pass needs to run in TorchScript for the forked calls to actually execute asynchronously.
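
As a rough, untested sketch (assuming your submodules are TorchScript-compatible), the forward pass above could be restructured along these lines, forking the independent branches and waiting on their futures:

import torch
import torch.nn as nn

class myModule(nn.Module):
    def __init__(self, size_x, size_y=5):
        super(myModule, self).__init__()
        self.f0 = nn.Linear(size_x, size_y)
        self.f1 = nn.Linear(size_x, size_y)
        self.g  = nn.Linear(2*size_y, size_y)
        self.h0 = nn.Linear(2*size_y, size_x)
        self.h1 = nn.Linear(2*size_y, size_x)

    def forward(self, x):
        # Launch f0 and f1 asynchronously; each fork returns a Future.
        fut0 = torch.jit.fork(self.f0, x[:,0,:])
        fut1 = torch.jit.fork(self.f1, x[:,1,:])
        y0 = torch.jit.wait(fut0)
        y1 = torch.jit.wait(fut1)

        # Blocking on g is fine, per the question.
        yg = self.g(torch.cat((y0,y1),dim=1))

        # Fork the two decoder branches as well.
        futh0 = torch.jit.fork(self.h0, torch.cat((y0,yg),dim=1))
        futh1 = torch.jit.fork(self.h1, torch.cat((y1,yg),dim=1))

        out = torch.zeros_like(x)
        out[:,0,:] = torch.jit.wait(futh0)
        out[:,1,:] = torch.jit.wait(futh1)
        return out

# The forks only run asynchronously once the module is scripted:
model = torch.jit.script(myModule(size_x=10))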

If the two branches are on different GPU devices, the backward pass will then run in parallel on the two GPUs. If everything is on the CPU, the backward pass currently has single-threaded execution only, so everything will run on a single CPU thread.
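
For example, a hypothetical two-GPU placement might look like the following (device strings and inter-device copies are illustrative, and this assumes two CUDA devices are available). This sketch only shows the placement; it could be combined with torch.jit.fork as above:

import torch
import torch.nn as nn

class twoGpuModule(nn.Module):
    def __init__(self, size_x, size_y=5):
        super(twoGpuModule, self).__init__()
        # Branch 0 lives on cuda:0 and branch 1 on cuda:1, so their
        # backward ops are issued on different devices.
        self.f0 = nn.Linear(size_x, size_y).to('cuda:0')
        self.h0 = nn.Linear(2*size_y, size_x).to('cuda:0')
        self.f1 = nn.Linear(size_x, size_y).to('cuda:1')
        self.h1 = nn.Linear(2*size_y, size_x).to('cuda:1')
        self.g  = nn.Linear(2*size_y, size_y).to('cuda:0')

    def forward(self, x):
        # Each branch computes on its own device.
        y0 = self.f0(x[:,0,:].to('cuda:0'))
        y1 = self.f1(x[:,1,:].to('cuda:1'))

        # g needs both encoder outputs on one device.
        yg = self.g(torch.cat((y0, y1.to('cuda:0')), dim=1))

        out0 = self.h0(torch.cat((y0, yg), dim=1))
        out1 = self.h1(torch.cat((y1, yg.to('cuda:1')), dim=1))
        return torch.stack((out0, out1.to('cuda:0')), dim=1)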

Thanks! This was very helpful.