Loss.backward() logic update!

I want to try communication primitives other than the allreduce that PyTorch distributed currently uses. Specifically, I am trying to change the communication logic that is overlapped with the backward computation. Which file do I need to modify to implement this change? Thanks.

If you are talking about the PyTorch DistributedDataParallel API, here is the entry point: pytorch/distributed.py at master · pytorch/pytorch · GitHub

Meanwhile, a DDP communication hook is a good way to override DDP's default gradient communication (allreduce) without touching the DDP source: DDP Communication Hooks — PyTorch 2.0 documentation
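A minimal sketch of what registering a comm hook looks like. The hook below simply re-implements the default allreduce-and-average (you would swap in your own primitive here), and the snippet uses a single-process gloo group purely for illustration:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def allreduce_hook(process_group, bucket):
    """Comm hook with the signature DDP expects: (state, GradBucket) -> Future.

    DDP calls this once per gradient bucket during backward, overlapping
    the communication with the remaining gradient computation. This version
    mirrors the default behavior: async allreduce, then divide by world size.
    """
    group = process_group if process_group is not None else dist.group.WORLD
    fut = dist.all_reduce(bucket.buffer(), group=group, async_op=True).get_future()

    def average(fut):
        # The allreduce future resolves to a list holding the reduced bucket tensor.
        return fut.value()[0] / dist.get_world_size(group)

    return fut.then(average)

# Single-process gloo group, just so the example runs standalone.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = DDP(torch.nn.Linear(4, 2))
model.register_comm_hook(state=None, hook=allreduce_hook)

# The hook fires during backward, overlapped with gradient computation.
model(torch.randn(8, 4)).sum().backward()
dist.destroy_process_group()
```

To experiment with a different primitive (e.g. reduce-scatter, gossip, or a compressed allreduce), you would replace the body of the hook while keeping the same `(state, bucket) -> Future` contract, so no changes to the DDP source tree are needed.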