Is there a way to write PyTorch nn.Module forward code so that we can pre-set certain operations to happen in parallel? Basically, turn the for loops into "smart" for loops.
For example, in quite a number of scenarios like parallel convolutions, we constantly use a for loop to go through each branch operation sequentially before proceeding to the next. Minimal code example:
    import torch
    from torch import nn

    class ParallelConv(nn.Module):
        def __init__(self):
            super().__init__()
            self.operations = nn.ModuleList([nn.Conv2d(....) for i in range(number_of_branches)])

        def forward(self, x):
            # Input 'x' is a list of tensors, one per branch/stream
            out = []
            for i, operation in enumerate(self.operations):
                out.append(operation(x[i]))
            return out  # also returns a list of tensors, one per branch
I do not think PyTorch can optimise this at runtime to execute all the convolutions in parallel, right? How can we go about doing it ourselves, then?
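For context, one workaround I have considered when all branches share the same convolution configuration is to fuse them into a single grouped convolution, so a single kernel handles every branch at once. A minimal sketch (the channel counts and sizes here are made-up placeholders):

```python
import torch
from torch import nn

# Hypothetical per-branch sizes, just for illustration.
num_branches, in_ch, out_ch = 3, 4, 8

# N identical parallel convs fused into one grouped convolution:
# each group sees only its own branch's channels.
fused = nn.Conv2d(in_ch * num_branches, out_ch * num_branches,
                  kernel_size=3, padding=1, groups=num_branches)

# Inputs arrive as a list of per-branch tensors; concatenate along channels.
xs = [torch.randn(1, in_ch, 16, 16) for _ in range(num_branches)]
out = fused(torch.cat(xs, dim=1))

# Split back into per-branch outputs.
outs = list(out.split(out_ch, dim=1))
```

This only applies when the branches have identical shapes and hyperparameters, which is not always the case.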
The same question applies to torch.jit / TorchScript: when we trace the input through the module setup, it actually goes through each branch operation sequentially before proceeding to the next. Can TorchScript identify during compilation that the inputs are separated into individual branches?
Advice is greatly appreciated.