The FSDP wrapper seems to use the regular torch.distributed NCCL primitives, but distributed autograd requires RPC primitives instead. Is it possible to use these two in conjunction? Is there any example of this?
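For concreteness, this is roughly the mismatch I mean. A minimal sketch, not working code; `rank`, `world_size`, `MyModel`, and the input tensors are assumed to be defined elsewhere:

```python
import torch
import torch.distributed as dist
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# FSDP setup: a regular NCCL process group; backward is ordinary local autograd
# driven by NCCL collectives under the hood.
dist.init_process_group("nccl", rank=rank, world_size=world_size)
model = FSDP(MyModel().cuda())
model(inputs).sum().backward()

# Distributed autograd setup: requires the RPC framework and its own backward
# entry point instead of Tensor.backward().
rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
with dist_autograd.context() as context_id:
    out = rpc.rpc_sync(f"worker{(rank + 1) % world_size}", torch.matmul, args=(a, b))
    dist_autograd.backward(context_id, [out.sum()])
```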
FSDP doesn’t support RPC primitives. Out of curiosity, in what scenario do you need distributed autograd with FSDP?
It’s for a custom sharding strategy I am implementing as part of a research project, more or less a ring-attention alternative.
The current approach for implementing ring-attention algorithms is to use send/recv together with customized forward/backward functions to perform the necessary ring computations. One example is torch/distributed/tensor/experimental/_attention.py in the pytorch/pytorch repo: https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/experimental/_attention.py
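In case it helps, here is a rough sketch of that pattern (not taken from that file; the name `RingExchange` and the helper `_shift` are just for illustration, and a process group that supports P2P ops, e.g. NCCL with CUDA tensors, is assumed). A custom `autograd.Function` shifts a block one step around the ring in forward and routes the gradient back the other way in backward:

```python
import torch
import torch.distributed as dist


class RingExchange(torch.autograd.Function):
    """Shift a tensor one step around the ring: rank r sends its block to
    rank (r + 1) % world_size and receives the block from rank r - 1.
    Backward routes the gradient in the opposite direction."""

    @staticmethod
    def _shift(tensor, src, dst):
        recv_buf = torch.empty_like(tensor)
        ops = [
            dist.P2POp(dist.isend, tensor.contiguous(), dst),
            dist.P2POp(dist.irecv, recv_buf, src),
        ]
        for req in dist.batch_isend_irecv(ops):
            req.wait()
        return recv_buf

    @staticmethod
    def forward(ctx, block):
        rank, world = dist.get_rank(), dist.get_world_size()
        ctx.rank, ctx.world = rank, world
        # Receive the previous rank's block, pass our own block to the next rank.
        return RingExchange._shift(block, (rank - 1) % world, (rank + 1) % world)

    @staticmethod
    def backward(ctx, grad_output):
        rank, world = ctx.rank, ctx.world
        # Our forward output was rank r - 1's input block, so its gradient is sent
        # back to rank r - 1; the gradient for our own block arrives from rank r + 1.
        return RingExchange._shift(grad_output, (rank + 1) % world, (rank - 1) % world)
```

A ring-attention style step is then a loop that repeatedly applies `RingExchange` to the K/V block and accumulates the partial attention output; autograd replays the shifts in reverse during backward, so no RPC or distributed autograd is involved.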
Thanks for the reference! I’ll probably end up doing it like this then; implementing the gradient calculations manually with NCCL seems like less work than reimplementing FSDP on top of RPC.