I have a branching model in which several submodules may be executed independently in parallel. Below is the simplest example I could come up with:
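import torch
import torch.nn as nn


class myModule(nn.Module):
    def __init__(self, size_x, size_y=5):
        super(myModule, self).__init__()
        # Independent per-channel encoders
        self.f0 = nn.Linear(size_x, size_y)
        self.f1 = nn.Linear(size_x, size_y)
        # Combines both encoder outputs into one shared vector
        self.g = nn.Linear(2*size_y, size_y)
        # Independent per-channel decoders
        self.h0 = nn.Linear(2*size_y, size_x)
        self.h1 = nn.Linear(2*size_y, size_x)

    def forward(self, x):
        # x has shape (batch, 2, size_x)
        y0 = self.f0(x[:,0,:])
        y1 = self.f1(x[:,1,:])
        y = torch.cat((y0,y1),dim=1)
        yg = self.g(y)
        # zeros_like keeps the output on the same device and dtype as x
        out = torch.zeros_like(x)
        out[:,0,:] = self.h0(torch.cat((y0,yg),dim=1))
        out[:,1,:] = self.h1(torch.cat((y1,yg),dim=1))
        return out  # the reconstruction, not the input x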
The functions f0 and f1 act independently on separate channels of the input. Function g combines the outputs of f0 and f1 into a single vector yg, which is then used by h0 and h1 (each along with the output of the corresponding f function) to attempt to recreate the input. I've just used linear layers as submodules in the example, but in my application the submodules are more complicated, are large in their own right, and there are more than two of them.
My goal is to execute f0 and f1 in parallel as well as h0 and h1 in parallel. I’m OK with blocking while g executes. I’m particularly interested in how to make sure that the backward pass executes in parallel.
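To make the intended execution pattern concrete, here is a rough sketch of what I have in mind, written with `torch.jit.fork` and `torch.jit.wait`. The class name `ForkedModule` and the sizes are made up for illustration. As far as I understand the documentation, `fork` only runs tasks asynchronously when the module is compiled with TorchScript; in plain eager mode (as below) the calls run one after another, so this shows the structure I want rather than a working solution. I also don't know whether it does anything for the backward pass.

```python
import torch
import torch.nn as nn


class ForkedModule(nn.Module):  # hypothetical name; same layers as myModule
    def __init__(self, size_x, size_y=5):
        super().__init__()
        self.f0 = nn.Linear(size_x, size_y)
        self.f1 = nn.Linear(size_x, size_y)
        self.g = nn.Linear(2 * size_y, size_y)
        self.h0 = nn.Linear(2 * size_y, size_x)
        self.h1 = nn.Linear(2 * size_y, size_x)

    def forward(self, x):
        # Launch both f branches; each fork returns a Future.
        fut_f0 = torch.jit.fork(self.f0, x[:, 0, :])
        fut_f1 = torch.jit.fork(self.f1, x[:, 1, :])
        y0 = torch.jit.wait(fut_f0)
        y1 = torch.jit.wait(fut_f1)

        # g is the synchronization point: it needs both f outputs.
        yg = self.g(torch.cat((y0, y1), dim=1))

        # Launch both h branches, then collect the results.
        fut_h0 = torch.jit.fork(self.h0, torch.cat((y0, yg), dim=1))
        fut_h1 = torch.jit.fork(self.h1, torch.cat((y1, yg), dim=1))
        out0 = torch.jit.wait(fut_h0)
        out1 = torch.jit.wait(fut_h1)

        # Stack along the channel dimension: (batch, 2, size_x)
        return torch.stack((out0, out1), dim=1)


model = ForkedModule(size_x=3)
x = torch.randn(4, 2, 3)  # (batch, channels, size_x)
out = model(x)
out.sum().backward()  # gradients should reach every submodule
```

Using `torch.stack` instead of writing into a preallocated tensor avoids the in-place assignments, but is otherwise meant to compute the same thing as the original `forward`.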