Parallelizing a Predictive Coding architecture

Hi everyone,
my team and I are developing a framework to train PyTorch models through predictive coding instead of relying on backpropagation. One of the main advantages of this approach is that every layer can run independently and in parallel with the others. However, we are struggling to understand how to implement this true parallelization in PyTorch.
The basic architecture is composed of several linear layers, each of which can perform its forward pass, compute its error, and backpropagate it. The only synchronization step required is when computing the error (as every layer requires the data from the previous one). In pseudocode, what we are aiming for is the following:

def forward_and_backward(x):
    # Run in parallel:
    layer1.forward()  # saves the output in layer1.output
    layer2.forward()  # ...and so on for every layer

    # Run in parallel (sync point: each error needs the previous layer's output):
    layer1.compute_error()
    layer2.compute_error()

    # Run in parallel:
    layer1.backward()
    layer2.backward()

The following picture shows the approach clearly:

First we compute all the black arrows (forward), then the green ones (loss), and then the red ones (backward). I use two energies so that I can compute each x’s loss from the current layer and from the next layer in two steps (the dotted arrows represent detached copies of tensors), by first calling backward on all E and then on all E’. The only approach I’ve tried is CUDA streams, but without getting any gains. Right now I’m simply experimenting with a fully connected NN with x independent layers that I try to run in parallel. However, no matter the layers’ width or the batch size, the execution time always grows linearly in x.
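For context, here is a minimal sketch of the kind of multi-stream forward I tried (the layer sizes and names are illustrative, not my actual model):

```python
import torch
import torch.nn as nn

# Sketch of a per-layer forward launched on separate CUDA streams.
# Sizes and layer count are illustrative placeholders.
device = "cuda" if torch.cuda.is_available() else "cpu"
layers = nn.ModuleList([nn.Linear(256, 256) for _ in range(10)]).to(device)
xs = [torch.randn(32, 256, device=device) for _ in range(10)]

if device == "cuda":
    streams = [torch.cuda.Stream() for _ in layers]
    outs = [None] * len(layers)
    for i, (layer, x, s) in enumerate(zip(layers, xs, streams)):
        with torch.cuda.stream(s):
            outs[i] = layer(x)
    torch.cuda.synchronize()  # wait for all streams before using the outputs
else:
    # On CPU the calls simply run sequentially.
    outs = [layer(x) for layer, x in zip(layers, xs)]
```

As far as I understand, streams only overlap if each kernel leaves SMs free; for small matmuls the time is dominated by launch overhead on the host, which streams do not remove.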

Any help on how to achieve this would be incredibly appreciated, as we are starting to do heavy research on Predictive Coding and we could speed up our training/inference time by at least 10x if we manage to do this. (Of course, the optimizer step should also be parallelized, if it isn’t already in the standard PyTorch framework.)
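Regarding the optimizer step: if I understand correctly, the foreach implementation in torch.optim already batches the update across all parameter tensors with a few multi-tensor kernels, instead of looping layer by layer. A minimal sketch (sizes are illustrative):

```python
import torch
import torch.nn as nn

# Sketch: one optimizer over many small layers, using the multi-tensor
# (foreach) update path so all parameters are updated in batched kernels.
model = nn.Sequential(*[nn.Linear(64, 64) for _ in range(10)])
opt = torch.optim.SGD(model.parameters(), lr=0.1, foreach=True)

x = torch.randn(8, 64)
loss = model(x).pow(2).mean()
loss.backward()

before = [p.detach().clone() for p in model.parameters()]
opt.step()  # a single step updates every layer's parameters
changed = any(not torch.equal(b, p.detach())
              for b, p in zip(before, model.parameters()))
```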

PyTorch will add the CUDA kernels to the queue and will launch them in parallel if enough resources are available. This post describes this behavior in more detail with some examples.
Since you are using linear layers, the matmul operation would most likely use all of the available compute resources, so no parallel execution might be possible.

Thank you for your reply.

If what you say is true, I would expect the computation time to also scale linearly with the batch size; however, there’s no difference in execution time between 32 and 1024 (nor across the hidden_dim of each of the 10 layers I’m using: from 256 to 1024). I imagine there’s a lot of overhead in launching the work compared to executing it. In the end, what I’m trying to do is compute a single linear layer which, for abstraction, is split into 10 layers that should run in parallel. I’m sure that if I build a model with 10 small layers versus one big layer, the performance would be very different. I’ve tried to look into things such as vmap (which doesn’t work because it doesn’t support indexing different layers) and torch.jit.fork (which helps a little, 15–25%, but introduces a lot of forking overhead since my layers are very simple and small), but without any luck.
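One direction for N identical, independent linear layers is to stack their weights once and replace the N small matmuls with a single batched one, so only one kernel is launched. A minimal sketch (sizes illustrative), checked against the layer-by-layer loop:

```python
import torch
import torch.nn as nn

# Sketch: fuse N independent same-shaped linear layers into one torch.bmm.
N, B, D = 10, 32, 256
layers = [nn.Linear(D, D) for _ in range(N)]
xs = torch.randn(N, B, D)  # one input batch per layer

# Stack the per-layer parameters (these could be kept stacked during training).
W = torch.stack([l.weight for l in layers])  # (N, D, D)
b = torch.stack([l.bias for l in layers])    # (N, D)

# One batched matmul instead of N separate layer calls.
fused = torch.bmm(xs, W.transpose(1, 2)) + b.unsqueeze(1)  # (N, B, D)

# Same result as calling each layer separately:
looped = torch.stack([l(x) for l, x in zip(layers, xs)])
match = torch.allclose(fused, looped, atol=1e-5)
```

The trade-off is that the layers must share shapes, and the parameters have to live in the stacked tensors for the fusion to pay off during training as well.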

This is what the forward pass and backward pass for each layer look like:
(the second half in the forward pass is the loss (energy) computation, which happens in every layer)

I call the backward pass after collecting the sum of all the individual losses, but in theory it could be called independently on each single one (I tried, but it is still executed sequentially, not even in parallel with the forward pass of the next layer).
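To illustrate what I mean by independent backward calls, here is a minimal sketch with detached inputs, so each layer’s energy only reaches that layer’s own parameters (the sizes and the quadratic energy are illustrative):

```python
import torch
import torch.nn as nn

# Sketch: per-layer local energy with detached inputs, so each backward()
# touches only one layer's parameters and the graphs are disjoint.
torch.manual_seed(0)
layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(3)])
x = torch.randn(4, 16)

activities = [x]
for layer in layers:
    # Detaching the incoming activity stops gradients at the layer boundary.
    activities.append(layer(activities[-1].detach()))

# Target activities standing in for the predictive-coding x nodes.
targets = [torch.randn_like(a) for a in activities[1:]]

# Each energy can be backpropagated independently of the others.
for pred, tgt in zip(activities[1:], targets):
    energy = 0.5 * (pred - tgt).pow(2).mean()
    energy.backward()

grads_ok = all(l.weight.grad is not None for l in layers)
```

Note that this buys independence (each graph is one layer deep), but the backward calls are still issued sequentially from the host unless they are also dispatched concurrently.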

I don’t know whether this helps you give me tips to improve performance.

Thank you so much for your time.