Does PyTorch compute the gradients for all nodes at the same level in parallel?

Assume we have a simple network represented by the computation graph below:

Here, M1 through M4 are all nn.Module instances, say each being

  nn.Sequential(
      nn.Linear(dim1, dim2),
      nn.Linear(dim2, dim3),
  )

Q1. Since M1 through M4 are independent and at the same level in the computation graph, will backward() compute their gradients in parallel? (Anything else would involve sequential operations and be suboptimal.) Please share a reference from the official docs if this is mentioned there.

Q2. What happens if there are too many nodes at the same level and, due to compute or memory constraints, a parallel backward() cannot be executed?
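The setup described above can be sketched as follows (a minimal, self-contained example; the four branches, the dimensions, and the final sum joining them are illustrative assumptions, not taken from the original graph):

```python
import torch
import torch.nn as nn

dim1, dim2, dim3 = 8, 16, 4

# Four independent modules at the same level of the graph (M1..M4).
branches = nn.ModuleList(
    nn.Sequential(nn.Linear(dim1, dim2), nn.Linear(dim2, dim3))
    for _ in range(4)
)

x = torch.randn(2, dim1)

# The branch outputs only meet at the final sum, so the gradient
# computations for the four branches are independent of one another.
out = sum(m(x) for m in branches).sum()
out.backward()
```

After backward(), every parameter in all four branches has its .grad populated; the question is whether those four independent gradient computations run concurrently.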

Yes, I believe it will run in parallel; parallel backward execution was introduced in this PR by @colesbury in 2017, but he can correct me if that's not the case.


I’m not sure of the current state of things. Probably @ezyang or @albanD know.

In 2017, the parallel execution was only across devices, not for individual modules on the same device.


It is still single-threaded per device.

I'm wondering: why doesn't this make PyTorch very slow (i.e., processing modules at the same level sequentially)? Shouldn't parallelizing this give a big speedup?

It depends a lot on how big the modules are. The most common use cases have very large modules that already fully saturate a single device when they run, so there is no opportunity to run multiple modules at the same time on the same device.

This is why we parallelize across devices but not within a single device, and why this does not slow down the common use case.
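The saturation argument can be illustrated with a rough micro-benchmark: one large matmul versus the same total work split into several small layers run one after another (sizes and timings here are arbitrary assumptions; on a GPU the single large kernel typically wins by an even wider margin):

```python
import time
import torch
import torch.nn as nn

x = torch.randn(256, 1024)

# One large layer: a single matmul that can saturate the device.
big = nn.Linear(1024, 4096)

# The same total work split into four small layers, run sequentially,
# mimicking four same-level modules executed one after another.
small = nn.ModuleList(nn.Linear(1024, 1024) for _ in range(4))

t0 = time.perf_counter()
y_big = big(x)
t_big = time.perf_counter() - t0

t0 = time.perf_counter()
y_small = torch.cat([m(x) for m in small], dim=1)
t_small = time.perf_counter() - t0

print(f"one large matmul: {t_big:.6f}s, four small matmuls: {t_small:.6f}s")
```

When the individual modules are already large enough to keep the device busy, interleaving them would not help, which is the common case the engine is optimized for.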