"Look-forward" in the execution graph to setup communication during distributed training?

Hello,

I’m working on the communication for a pipeline-parallelism (PP) training setup in which the nn.Modules of a model can be located on arbitrary devices/ranks, connected in an arbitrary way that may change during execution.

In order to set up the communication between the different nn.Modules, I would need additional information about the following nn.Modules during the forward (and backward) computation.

Example:
An nn.Module has executed its forward method and then needs to know where to send its results. For that, it would need additional information, such as the rank of the next nn.Module(s), to set up the point-to-point communication.
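
For illustration, here is a minimal sketch of the kind of handover I have in mind, assuming a torch.distributed process group is already initialized and `next_rank` is the (hypothetical) routing metadata the module would receive at runtime:

```python
import torch
import torch.distributed as dist


def forward_and_send(module, inputs, next_rank):
    """Run the local stage and ship its output to the next stage (sketch)."""
    output = module(inputs)
    # The receiver has to allocate a buffer of the right shape, so send a
    # small shape header first (assumes the number of dimensions is agreed on).
    shape = torch.tensor(output.shape, dtype=torch.long)
    dist.send(shape, dst=next_rank)
    dist.send(output.contiguous(), dst=next_rank)
    return output


def receive_from(prev_rank, ndim, dtype=torch.float32):
    """Counterpart on the next rank: receive the shape header, then the payload."""
    shape = torch.empty(ndim, dtype=torch.long)
    dist.recv(shape, src=prev_rank)
    payload = torch.empty(tuple(shape.tolist()), dtype=dtype)
    dist.recv(payload, src=prev_rank)
    return payload
```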

One way to do that (afaik) is to get the execution graph beforehand with torch.fx, but is there a way to do this at the end of every nn.Module so that it can adapt to changes in the execution?
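
As a rough idea of what I mean by "at the end of every nn.Module": a forward hook could look up the current placement each time it fires, instead of relying on a graph traced ahead of time. This is only a sketch; `lookup_next_rank` is a placeholder for whatever registry would hold the placement information:

```python
import torch.distributed as dist


def make_routing_hook(lookup_next_rank):
    """Build a forward hook that routes a module's output to its consumer."""
    def hook(module, inputs, output):
        next_rank = lookup_next_rank(module)   # may change between iterations
        if next_rank is not None and next_rank != dist.get_rank():
            # Only the forward direction is covered here; the backward pass
            # would need a matching mechanism (e.g. a custom autograd.Function).
            dist.send(output.detach().contiguous(), dst=next_rank)
    return hook


# for submodule in model.children():
#     submodule.register_forward_hook(make_routing_hook(lookup_next_rank))
```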

I’m happy to hear suggestions, tips, tricks, for this scenario.

Thank you :slight_smile:

Thanks for your question. Could you kindly provide a code example here? Thanks!

Upon checking with my teammates: distributed point-to-point communication can be used in this scenario, and RPC can also be used here. cc: @H-Huang

Hi Hugo,

thank you very much for your reply. :slight_smile:

What we are trying to set up is similar to the pipeline-parallel part outlined in this paper:
Amazon SageMaker Model Parallelism: A General and Flexible Framework for Large Model Training, section “4.1 Overview”, but instead of using some kind of “execution server” we want to set this up with PyTorch itself.

So, in the best case, we would be able to trigger a point-to-point communication whenever data needs to be exchanged between two nn.Modules that are on two different devices.

Please tell me if I should explain some parts better etc. :slight_smile:

Hi Michael, you should take a look at the Distributed RPC Framework in PyTorch. It natively provides features such as optimized tensor communication, remote memory management, and distributed autograd. You can send the tensors of one module to another using remote references (RRefs). Tensors and gradients sent over RPC during a forward pass are recorded, and this information is used to perform a distributed backward pass over RPC. One small note: remote references do not currently support CUDA tensors, so if the modules are on different devices you would need to transfer the tensors to CPU and then use the APIs.
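
To make that concrete, here is a minimal sketch of the RRef / distributed autograd flow, assuming `rpc.init_rpc` has already been called on every worker; the worker names, layer sizes, and the helper `run_remote_forward` are just illustrative:

```python
import torch
import torch.nn as nn
import torch.distributed.rpc as rpc
import torch.distributed.autograd as dist_autograd


# On "worker0": place the next stage on "worker1" and keep a remote reference to it.
stage2_rref = rpc.remote("worker1", nn.Linear, args=(16, 4))


def run_remote_forward(rref, x):
    # Runs forward on the owner of the RRef; tensors sent over RPC inside a
    # distributed autograd context are recorded for the backward pass.
    return rref.rpc_sync().forward(x)


with dist_autograd.context() as context_id:
    x = torch.randn(8, 16)                       # CPU tensors (see the CUDA note above)
    out = run_remote_forward(stage2_rref, x)     # forward crosses the RPC boundary
    loss = out.sum()
    dist_autograd.backward(context_id, [loss])   # distributed backward over RPC
    # Gradients for the remote parameters live in this context and can be
    # applied with torch.distributed.optim.DistributedOptimizer.
```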

See an example of distributed pipeline parallelism using RPC in the PyTorch tutorials!