How to prevent default gradient synchronization when calling loss.backward()?

Hi, I’m quite new to PyTorch. I’m using DistributedDataParallel (DDP) and trying to implement gradient quantization for a project, i.e. I want to compress the gradients before they are synchronized across processes.

From exploring around a bit, I know that with DDP, loss.backward() automatically all-reduces the gradients across processes. Is there a way to prevent this behaviour so that I get only the local gradients and can then synchronize them manually myself?
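
Roughly, this is what I have in mind for each step (just a sketch; quantize/dequantize stand in for my compression scheme, here shown as a simple fp16 cast):

```python
import torch
import torch.distributed as dist

def quantize(grad):
    # Placeholder for my compression scheme: here just a cast to fp16.
    return grad.to(torch.float16)

def dequantize(grad, dtype):
    return grad.to(dtype)

loss.backward()  # I would like this to produce only *local* gradients

for param in model.parameters():
    if param.grad is None:
        continue
    compressed = quantize(param.grad)
    dist.all_reduce(compressed)  # manual synchronization
    param.grad = dequantize(compressed, param.grad.dtype) / dist.get_world_size()
```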

You can register a communication hook on the DDP model (DistributedDataParallel.register_comm_hook) to add your customized compression logic at the point where the gradients are synced; DDP will then call your hook on each gradient bucket instead of running its default all-reduce.
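
For example, a minimal sketch of such a hook (assuming a recent PyTorch where GradBucket.buffer() is available). It casts each bucket to fp16 before the all-reduce and back afterwards; a real quantization scheme would replace the casts with its own compress/decompress steps:

```python
import torch
import torch.distributed as dist

def fp16_compress_hook(state, bucket):
    """Cast each gradient bucket to fp16, all-reduce it, then cast back."""
    group = state if state is not None else dist.group.WORLD
    world_size = group.size()

    # Compress: average the bucket in fp16 instead of fp32.
    compressed = bucket.buffer().to(torch.float16).div_(world_size)
    fut = dist.all_reduce(compressed, group=group, async_op=True).get_future()

    def decompress(fut):
        # Cast the reduced bucket back to the original gradient dtype.
        return fut.value()[0].to(bucket.buffer().dtype)

    return fut.then(decompress)

# Usage, assuming ddp_model is your DistributedDataParallel instance:
# ddp_model.register_comm_hook(state=None, hook=fp16_compress_hook)
```

PyTorch also ships ready-made compression hooks (e.g. fp16 compression and PowerSGD) under torch.distributed.algorithms.ddp_comm_hooks, which you can use as a starting point for your own quantization hook.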