How does pytorch handle backward pass in a multi-GPU setting?

Past works have often proposed remote GPU caching as a performance optimization. For example, if data x originally stored on GPU0 is requested by GPU1, then x is cached in GPU1’s L1 or L2 cache (there are pros and cons to caching in L1 versus L2, depending on the workload).

While applying remote caching is trivial for inference, getting it to work for training seems to be more challenging, primarily because we need to maintain coherence across all cached copies after the backward pass. I was curious to know how PyTorch actually handles all this under the hood. I’ve read that the gradients are calculated based on a computational graph.

Some questions I have:

  • in the case of model-parallel training, does the computational graph also store which device holds the required weights? (I think the answer is yes)
  • what does PyTorch do if data is remotely cached? (there are two options here: invalidate all cached copies and update only the original data, or update both the original and the cached copies)

Any pointers to find the answers to these questions would be great. Thanks!

PS: I was specifically trying to find out how the gradient updates for embedding tables occur for DLRM (see issue #353 in facebookresearch/dlrm on GitHub).

Hey @christindbose

in the case of model-parallel training, does the computational graph also store which device holds the required weights? (I think the answer is yes)

This depends on which training paradigm you are using:

Local: PyTorch eager mode builds the autograd graph during the forward pass. If the forward pass uses multiple devices on the same machine, there will be copy operators representing the device-to-device tensor copies. Those copy operators are recorded in the autograd graph. So yep, in this case, the graph has the device information.
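
For illustration, here is a minimal sketch of the Local case (assuming two CUDA devices, cuda:0 and cuda:1, are available; the model and sizes are made up): the .to("cuda:1") call in forward() is what gets recorded as a copy op, and backward() routes gradients back across it.

```python
import torch
import torch.nn as nn

# Minimal sketch: a tiny model split across two GPUs. The .to("cuda:1")
# inside forward() becomes a device-to-device copy op in the autograd
# graph, so backward() knows which device each gradient belongs to.
class TwoDeviceNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 32).to("cuda:0")
        self.fc2 = nn.Linear(32, 4).to("cuda:1")

    def forward(self, x):
        h = torch.relu(self.fc1(x.to("cuda:0")))
        return self.fc2(h.to("cuda:1"))  # D2D copy, tracked by autograd

model = TwoDeviceNet()
loss = model(torch.randn(8, 16)).sum()
loss.backward()                       # gradients flow back across the copy
print(model.fc1.weight.grad.device)   # cuda:0
print(model.fc2.weight.grad.device)   # cuda:1
```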

SPMD: For Single-Program Multi-Data style distributed/parallel training, each GPU device is usually kind-of “independent” of the other GPUs in terms of forward/backward computations. Different training paradigms, such as DDP/FSDP, use different communications to synchronize those devices: DDP AllReduces gradients after the backward computation, while FSDP AllGathers full parameter tensors before computation and ReduceScatters gradients after the backward computation. So, in this case, the autograd graph does not record remote device information. Instead, it relies on ProcessGroup and the NCCL/Gloo comm libraries to talk to other devices.
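
As a rough sketch of the DDP flavor (assuming the script is launched with one process per GPU, e.g. via torchrun, so the process group can be initialized from environment variables):

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Sketch only: assumes launch via `torchrun --nproc_per_node=<num_gpus> script.py`.
dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

model = nn.Linear(16, 4).to(rank)
ddp_model = DDP(model, device_ids=[rank])  # registers hooks that AllReduce grads

out = ddp_model(torch.randn(8, 16, device=rank))
out.sum().backward()   # gradients are AllReduced across all ranks here
dist.destroy_process_group()
```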

MPMD: For Multi-Program Multi-Data style distributed training (e.g., PyTorch RPC), there will be a distributed autograd graph that spans multiple processes and remote GPUs. It attempts to treat a remote GPU as if it were a local GPU.
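
A hedged sketch of that RPC style, following the distributed autograd pattern (assuming two workers, "worker0" and "worker1", have already called rpc.init_rpc); the RPC call becomes an edge in the distributed autograd graph spanning both processes:

```python
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

# Sketch only: assumes rpc.init_rpc("worker0", rank=0, world_size=2) has been
# called here, and a peer process has joined as "worker1".
def remote_matmul(x, w):
    return x @ w

with dist_autograd.context() as context_id:
    x = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(8, 2, requires_grad=True)
    # Forward runs on worker1; the RPC edge is recorded in the
    # distributed autograd graph spanning both processes.
    y = rpc.rpc_sync("worker1", remote_matmul, args=(x, w))
    dist_autograd.backward(context_id, [y.sum()])
    grads = dist_autograd.get_gradients(context_id)  # {tensor: grad} for this context
```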

what does PyTorch do if data is remotely cached? (there are two options here: invalidate all cached copies and update only the original data, or update both the original and the cached copies)

If you are referring to input data, this is usually handled by data loaders. As input is read-only, we don’t need to worry about cached copies.
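
As a concrete (hedged) illustration: in the SPMD case, each rank typically reads its own read-only shard of the input via a sampler; a minimal sketch, assuming a process group is already initialized:

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# Sketch: each rank iterates over its own disjoint, read-only shard of the
# dataset, so there are no cached input copies to keep coherent.
dataset = TensorDataset(torch.randn(1024, 16), torch.randint(0, 2, (1024,)))
sampler = DistributedSampler(dataset)            # shards by rank / world_size
loader = DataLoader(dataset, batch_size=32, sampler=sampler)
```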

If you are referring to model parameters, there are two different styles:

  1. Synchronous: in every iteration, all trainers read the parameters from the previous iteration, generate gradients locally, synchronize gradients globally, and then update the parameters, so there won’t be any staleness here (see the sketch after this list).
  2. Asynchronous: each trainer might proceed at its own pace, with some optional global staleness bound. This is an open research area. So, no fixed answer for this one. :slight_smile:
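
For point 1, here is a minimal sketch of what the synchronous style boils down to (assuming a process group is already initialized and `param` is an identical replica on every rank); DDP essentially automates this during backward():

```python
import torch
import torch.distributed as dist

# Sketch of one synchronous step, assuming dist.init_process_group() has been
# called and `param` is an identical replica on every rank.
def sync_step(param: torch.Tensor, loss: torch.Tensor, lr: float = 0.1):
    loss.backward()                                    # 1. local gradients
    dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)  # 2. global synchronization
    param.grad /= dist.get_world_size()                #    average the gradients
    with torch.no_grad():
        param -= lr * param.grad                       # 3. identical update on all ranks
    param.grad.zero_()
```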

Also, GPU L1 and L2 caches are really small; they are usually not large enough to hold even an individual tensor. It’s usually up to the operator kernel implementation to decide how to utilize shared memory and global memory on the GPU. Every kernel reads its input from CUDA global memory and writes its output to CUDA global memory. If we treat kernel design and distributed training design as two different domains, L1/L2 cache optimization might belong to the former.


Hi Shen Li @mrshenli,

Thanks a lot for your detailed answer. Very helpful.

I was particularly interested in optimizing a single-node multi-device run for a DLRM model. As you may already know, the embedding layers in DLRM are model parallel (i.e., split across many GPUs), while the fully connected layers are trained in a data-parallel manner. I believe this should fall under the ‘Local’ training paradigm in your description above.
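
For concreteness, here is a hedged sketch of how that hybrid setup might look on a single node (table sizes and placement are made up for illustration): the embedding tables are sharded across devices and their gradients land on each table’s home device, while in a real DLRM run the MLP below would additionally be wrapped in DDP.

```python
import torch
import torch.nn as nn

# Illustrative sketch only: two embedding tables sharded across two GPUs
# (model parallel); the MLP is kept on cuda:0 and would be wrapped in
# DistributedDataParallel in a real run (data parallel).
table0 = nn.EmbeddingBag(1_000_000, 64, mode="sum").to("cuda:0")
table1 = nn.EmbeddingBag(1_000_000, 64, mode="sum").to("cuda:1")
mlp = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 1)).to("cuda:0")

ids0 = torch.randint(0, 1_000_000, (32, 4), device="cuda:0")
ids1 = torch.randint(0, 1_000_000, (32, 4), device="cuda:1")

e0 = table0(ids0)               # lookup on cuda:0
e1 = table1(ids1).to("cuda:0")  # lookup on cuda:1, copied to cuda:0 (tracked by autograd)
loss = mlp(torch.cat([e0, e1], dim=1)).sum()
loss.backward()                 # each table's gradient stays on its home device
print(table1.weight.grad.device)  # cuda:1
```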

In order to reduce inter-GPU communication, one optimization for future GPU architectures is to cache frequently used embedding parameters so as to reduce off-chip memory accesses. As you already mentioned, the L1/L2 caches are very small compared to the data locality exhibited by these workloads, so past research works have explored devoting a portion of the GPU HBM to caching model parameters (note that while HBM still does not provide the bandwidth of an L1/L2 cache, this is still better than accessing data remotely over off-chip links such as NVLink or PCIe).

So my question becomes: how is gradient synchronization performed in this scenario (assuming a synchronous training approach)? Do libraries such as NCCL perform something like a broadcast operation on the gradients here?

Any advice or pointers to find these answers would be helpful.

Thanks,
Christin