How to run independent operations in parallel in an nn.Module forward pass?

Is there a way to write PyTorch nn.Module forward code so that certain operations are marked in advance to run in parallel? Basically, turn the for loops into “smart” for loops.

For example, in many scenarios such as parallel convolutions, we use a for loop that runs each branch's operation to completion before moving on to the next branch. Minimal code example:

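```python
import torch
import torch.nn as nn


class ParallelConv(nn.Module):
    def __init__(self, number_of_branches, in_channels, out_channels, kernel_size):
        super().__init__()
        # One independent convolution per branch/stream
        self.operations = nn.ModuleList(
            [nn.Conv2d(in_channels, out_channels, kernel_size) for _ in range(number_of_branches)]
        )

    def forward(self, x):
        # Input 'x' is a list of tensors, one per branch/stream
        out = []
        for i, operation in enumerate(self.operations):
            out.append(operation(x[i]))  # branches are executed one after another
        return out  # also a list of tensors, one per branch
```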

I don't think PyTorch can optimise this at runtime to execute all the convolutions in parallel, right? How can we do it ourselves, then?

The same question applies to torch.jit / TorchScript: when we trace the input through the modules, it still goes through each branch's operation sequentially before proceeding to the next. Can TorchScript recognise during compilation that the inputs are split into independent branches?

Advice is greatly appreciated.


Realistically, unless the operations can be executed concurrently on multiple devices, I wouldn’t expect any speedup when executing eagerly. In fact, there can be slowdowns attributed to multiple kernels contending for shared hardware resources.
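On a single device, one alternative to looping is to fuse the branches into a single grouped convolution, assuming every branch has the same channels, kernel size and input shape, so the inputs can be concatenated along the channel dimension. All branches then become one kernel launch instead of several kernels competing for the same hardware. The class name `GroupedParallelConv` and its arguments are hypothetical, and whether this beats the loop depends on the backend's grouped-convolution performance, so it is worth profiling.

```python
import torch
import torch.nn as nn


class GroupedParallelConv(nn.Module):
    """Sketch: B independent Conv2d branches fused into one grouped Conv2d.

    Assumes all branches share in/out channels, kernel size and input shape.
    """

    def __init__(self, number_of_branches, in_channels, out_channels, kernel_size):
        super().__init__()
        self.b = number_of_branches
        # groups=B: output channel group i only sees input channel group i (branch i).
        self.conv = nn.Conv2d(
            in_channels * number_of_branches,
            out_channels * number_of_branches,
            kernel_size,
            groups=number_of_branches,
        )

    def forward(self, x):
        # x: list of B tensors, each (N, in_channels, H, W)
        y = self.conv(torch.cat(x, dim=1))  # one convolution call for all branches
        return list(torch.chunk(y, self.b, dim=1))  # split back into one tensor per branch
```

Branch i corresponds to output channels `i * out_channels` to `(i + 1) * out_channels` of `self.conv`, so weights from an existing `ModuleList` of branches can be copied into those slices of `self.conv.weight` and `self.conv.bias`.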

If you have multiple devices, you might check whether running the operations on different devices (e.g., in a model-parallel fashion) speeds things up, although this may add overhead from first communicating the inputs across the devices.
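A minimal model-parallel sketch, assuming one branch per device: branch i lives on `devices[i]`, its input is copied there, and the results are gathered on one output device. Because CUDA kernel launches are asynchronous, the Python loop can issue work on the next GPU while the previous one is still computing. The class name, the `devices` and `output_device` parameters are hypothetical; the copies in and out are the communication overhead mentioned above.

```python
import torch
import torch.nn as nn


class MultiDeviceParallelConv(nn.Module):
    """Sketch: branch i runs on devices[i]; outputs are gathered on output_device."""

    def __init__(self, devices, in_channels, out_channels, kernel_size, output_device="cpu"):
        super().__init__()
        self.devices = [torch.device(d) for d in devices]
        self.output_device = torch.device(output_device)
        # One branch per device (assumption: number of branches == len(devices)).
        self.operations = nn.ModuleList(
            [nn.Conv2d(in_channels, out_channels, kernel_size).to(d) for d in self.devices]
        )

    def forward(self, x):
        # Copy each input to its branch's device and launch the branch there;
        # non_blocking only helps when the source tensor is in pinned memory.
        outs = [op(xi.to(d, non_blocking=True)) for op, xi, d in zip(self.operations, x, self.devices)]
        # Gathering waits for each device in turn.
        return [o.to(self.output_device) for o in outs]


# Usage (falls back to two CPU "devices" when no GPU is present):
# devices = [f"cuda:{i}" for i in range(torch.cuda.device_count())] or ["cpu", "cpu"]
# model = MultiDeviceParallelConv(devices, 2, 4, 3)
```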

This is also relevant for implementing 4D convolutions via 3D convolutions: the for loop forces sequential execution of an operation that is fully parallelizable. Any pointers?
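For the 4D-via-3D case, one way to remove the Python loop, assuming stride 1 and no padding along the fourth dimension, is to fold the sliding window over that dimension into the channel axis and the output positions into the batch axis, so a single `F.conv3d` call does all the work. The layout `(N, C, T, D, H, W)` and both function names are assumptions for illustration; the loop version is kept as a reference. The trade-off is memory: the folded input is roughly `kT` times larger than `x`.

```python
import torch
import torch.nn.functional as F


def conv4d_loop(x, weight, bias=None, padding=0):
    """Reference: 4D conv as explicit 3D convs summed over the 4th kernel dim.

    x: (N, C_in, T, D, H, W), weight: (C_out, C_in, kT, kD, kH, kW).
    'Valid' (no padding, stride 1) along T; `padding` applies to D, H, W.
    """
    kT = weight.shape[2]
    T_out = x.shape[2] - kT + 1
    outs = [
        sum(F.conv3d(x[:, :, t + j], weight[:, :, j], padding=padding) for j in range(kT))
        for t in range(T_out)
    ]
    y = torch.stack(outs, dim=2)  # (N, C_out, T_out, D', H', W')
    if bias is not None:
        y = y + bias.view(1, -1, 1, 1, 1, 1)
    return y


def conv4d_folded(x, weight, bias=None, padding=0):
    """Same result with a single conv3d call and no Python loop."""
    N, C_in, T = x.shape[:3]
    C_out, _, kT = weight.shape[:3]
    T_out = T - kT + 1
    win = x.unfold(2, kT, 1)  # sliding windows along T: (N, C_in, T_out, D, H, W, kT)
    # -> (N, T_out, C_in, kT, D, H, W); fold T_out into batch and (C_in, kT) into channels
    win = win.permute(0, 2, 1, 6, 3, 4, 5).reshape(N * T_out, C_in * kT, *x.shape[3:])
    w3d = weight.reshape(C_out, C_in * kT, *weight.shape[3:])  # same (C_in, kT) flattening
    y = F.conv3d(win, w3d, bias=bias, padding=padding)  # (N*T_out, C_out, D', H', W')
    return y.reshape(N, T_out, C_out, *y.shape[2:]).permute(0, 2, 1, 3, 4, 5)
```

The `permute` before `reshape` makes a copy, which is where the extra memory goes; if that is too much, the windows can be processed in chunks along `T_out`.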