Loss.backward() logic update!

I want to try communication primitives other than the allreduce that PyTorch distributed currently uses. I am trying to update the communication logic that is overlapped with the backward computation. Which file do I need to change to implement this? Thanks.

If you are talking about the PyTorch DistributedDataParallel API, here is the entry point: pytorch/distributed.py at master · pytorch/pytorch · GitHub

Meanwhile, a DDP comm hook is a good way to override DDP's communication: DDP Communication Hooks — PyTorch 2.0 documentation
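
For example, here is a minimal sketch of a comm hook that just reproduces DDP's default allreduce averaging, to show where custom communication plugs in (the hook name is arbitrary; adapt to your setup):

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def allreduce_avg_hook(state, bucket):
    # Divide first so the allreduce produces the mean, matching DDP's default behavior.
    tensor = bucket.buffer().div_(dist.get_world_size())
    fut = dist.all_reduce(tensor, async_op=True).get_future()
    # The future resolves to a list of tensors; DDP expects the single bucket tensor back.
    return fut.then(lambda f: f.value()[0])

# After wrapping your module (assuming the process group is already initialized):
# ddp_model = DDP(module, device_ids=[local_rank])
# ddp_model.register_comm_hook(state=None, hook=allreduce_avg_hook)
```

Anything you return from the hook replaces the bucket's gradients, so you can swap in whatever communication you like as long as the returned future resolves to a tensor with the same shape as bucket.buffer().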

Hi @Yanli_Zhao! Thanks for the prompt reply. I want to implement a parameter server using torch.distributed.send and torch.distributed.recv instead of the current decentralized approach based on torch.distributed.all_reduce. Is this possible by writing a custom communication hook and registering it with torch.nn.parallel.DistributedDataParallel.register_comm_hook()? If so, do you have any tutorials in mind that might help? Thank you again.
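
To make it concrete, here is a rough, untested sketch of what I have in mind. It is blocking (so it gives up the compute/communication overlap) and assumes rank 0 doubles as the parameter server; a real setup might instead use a dedicated server process or torch.distributed.rpc:

```python
import torch
import torch.distributed as dist

SERVER_RANK = 0  # assumption: rank 0 acts as the parameter server

def parameter_server_hook(state, bucket):
    grad = bucket.buffer()
    world_size = dist.get_world_size()

    if dist.get_rank() == SERVER_RANK:
        # Server: receive one gradient tensor from each worker and accumulate.
        total = grad.clone()
        recv_buf = torch.empty_like(grad)
        for src in range(world_size):
            if src == SERVER_RANK:
                continue
            dist.recv(recv_buf, src=src)
            total += recv_buf
        total /= world_size
        # Send the averaged gradient back to every worker.
        for dst in range(world_size):
            if dst == SERVER_RANK:
                continue
            dist.send(total, dst=dst)
        result = total
    else:
        # Worker: ship the local gradient to the server, then wait for the average.
        dist.send(grad, dst=SERVER_RANK)
        result = torch.empty_like(grad)
        dist.recv(result, src=SERVER_RANK)

    # DDP expects a future that resolves to the bucket-shaped tensor.
    fut = torch.futures.Future()
    fut.set_result(result)
    return fut

# ddp_model.register_comm_hook(state=None, hook=parameter_server_hook)
```

Would registering something like this via register_comm_hook() be the recommended route, or is there a better-supported way to do parameter-server-style training?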