How does PyTorch handle the backward pass in a multi-GPU setting? (DLRM use case)

Prior work has often proposed remote GPU caching as a performance optimization. As an example, if data x originally stored on GPU0 is requested by GPU1, then x is cached in GPU1’s L1 or L2 cache (whether L1 or L2 is the better choice depends on the workload).

While applying remote caching is trivial for inference, getting it to work for training seems more challenging, primarily because we need to maintain coherence across all cached copies after the backward pass. I was curious how PyTorch actually handles all this under the hood. I’ve read that the gradients are calculated based on a computational graph.

Some questions I have:

  • In the case of model-parallel training, does the computational graph also record which device holds the required weights? (I think the answer is yes.)
  • What does PyTorch do if data is remotely cached? There are two options here: invalidate all cached copies and update only the original data, or update both the original and the cached copies.
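
To make the first question concrete, here is a minimal sketch of the kind of model-parallel setup I have in mind. The two-stage model, the weight names `w0` and `w1`, and the shapes are all hypothetical; the sketch falls back to CPU when two GPUs are not available. As far as I can tell, the cross-device copy `.to(dev1)` is recorded as an ordinary differentiable operation, so each gradient lands on the device that owns the corresponding weight. It says nothing about what happens in the hardware L1/L2 caches.

```python
import torch

# Use two GPUs when available; otherwise fall back to CPU so the sketch still runs.
if torch.cuda.device_count() >= 2:
    dev0, dev1 = torch.device("cuda:0"), torch.device("cuda:1")
else:
    dev0 = dev1 = torch.device("cpu")

# Hypothetical two-stage model: w0 lives on dev0, w1 lives on dev1.
w0 = torch.randn(4, 3, device=dev0, requires_grad=True)
w1 = torch.randn(3, 2, device=dev1, requires_grad=True)

x = torch.randn(5, 4, device=dev0)
h = x @ w0              # computed on dev0
h_remote = h.to(dev1)   # cross-device copy, part of the autograd graph
loss = (h_remote @ w1).sum()
loss.backward()         # gradients flow back across the copy

# Each gradient is stored on the same device as its weight.
print("w0:", w0.device, "grad:", w0.grad.device)
print("w1:", w1.device, "grad:", w1.grad.device)
```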

Any pointers to find the answers to these questions would be great. Thanks!

PS: I was specifically trying to find out how the gradient updates for embedding tables occur for DLRM. Related GitHub issue: "How does pytorch handle backward pass in a multi-GPU setting?", Issue #353, facebookresearch/dlrm.
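
For the embedding-table case, this is a small single-device sketch of what I understand the gradient to look like. It assumes the table is an `nn.EmbeddingBag` created with `sparse=True`; the table size, indices, and offsets are made up for illustration. With that setting the weight gradient is a sparse tensor that only contains the rows that were actually looked up, which seems relevant to the question of which copies would need to be updated or invalidated.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical small table: 10 rows, embedding dimension 4.
table = nn.EmbeddingBag(10, 4, mode="sum", sparse=True)

# Two bags: rows [1, 3] and rows [3, 7].
indices = torch.tensor([1, 3, 3, 7])
offsets = torch.tensor([0, 2])

out = table(indices, offsets)
out.sum().backward()

# The gradient is sparse; coalescing merges the duplicate entries for row 3.
grad = table.weight.grad.coalesce()
touched_rows = grad.indices()[0].tolist()
print("sparse gradient:", table.weight.grad.is_sparse)
print("rows with a gradient:", touched_rows)
```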