Custom gradient averaging with DDP?

DDP averages gradients by dividing by world size. Is there any mechanism (current or planned) to run a user-defined function to scale gradients instead of the default DDP behavior?

In my case, I have variable and uneven batch sizes on each replica and need to compute the average by global batch size.

The default for DDP is to allreduce the gradients and average them by the world size. You can choose to allreduce the gradients yourself and specify the reduction op, but in that case you lose out on DDP's better perf, and allreduce still won't accept a custom UDF as the reduction op (though you can use the SUM op and then average the gradients however you want, as in the sketch below).
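A minimal sketch of that manual approach, assuming a plain (non-DDP) module and an already initialized process group; the function name and variables are just illustrative:

```python
import torch
import torch.distributed as dist

def average_by_global_batch(model, local_batch_size, device):
    """Sum gradients across ranks, then divide by the global batch size.

    Note: per-parameter all_reduce calls like this do not overlap with the
    backward pass or use bucketing, which is exactly the DDP perf you give up.
    """
    # Each rank contributes its own (possibly different) batch size.
    bs = torch.tensor([float(local_batch_size)], device=device)
    dist.all_reduce(bs, op=dist.ReduceOp.SUM)
    global_batch_size = bs.item()

    # SUM the gradients, then rescale by the global batch size
    # instead of the world size.
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad.div_(global_batch_size)
```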

On a slightly different note, the DDP join API can help you better handle uneven inputs across ranks.
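Roughly, the join context manager is used like this (assuming `model` is already wrapped in DistributedDataParallel and each rank has its own dataloader):

```python
# Ranks may run different numbers of iterations; join() keeps the
# collective calls consistent once a rank runs out of data.
with model.join():
    for batch in loader:
        optimizer.zero_grad()
        loss = model(batch).sum()  # placeholder loss for illustration
        loss.backward()
        optimizer.step()
```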


Yeah, I know I could just write my own gradient averaging but I lose out on the perf wins. And it seems kind of silly to maintain a fork of DDP just to change the reduction from mean to sum. I’ve looked at the join method on DDP but it doesn’t apply to my issue. Thanks for the reply!

Sounds good! We are also working on a broader effort to make DDP more modular and configurable across the board. This definitely seems like an interesting direction to support, so feel free to open an issue on our GitHub repo or add to the RFC tracking the DDP configurability effort: [RFC] Modularize DistributedDataParallel · Issue #37002 · pytorch/pytorch · GitHub

We have landed gradient compression communication hooks for use with PyTorch DDP in 1.8: DDP Communication Hooks — PyTorch 1.8.1 documentation. This should support your use case quite nicely, as you can run a custom UDF for the gradient reduction instead of the default allreduce.
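For instance, a hook along these lines could sum the gradients and divide by the global batch size. The `state` dict convention and the `global_batch_size_hook` name are just illustrative, and the exact `GradBucket` accessor and return convention have shifted slightly across releases (`get_tensor()` in early versions, `buffer()` later):

```python
import torch
import torch.distributed as dist

def global_batch_size_hook(state, bucket):
    # `state` is a user-supplied object; here it is assumed to hold the
    # global batch size for the current step (e.g. obtained by
    # all-reducing the local batch sizes before backward()).
    group = state.get("process_group") or dist.group.WORLD
    buf = bucket.buffer()  # `bucket.get_tensor()` on older PyTorch releases

    # SUM instead of the default average-by-world-size.
    fut = dist.all_reduce(buf, op=dist.ReduceOp.SUM,
                          group=group, async_op=True).get_future()

    def rescale(fut):
        grad = fut.value()[0]
        grad.div_(state["global_batch_size"])
        return grad

    return fut.then(rescale)

# Hypothetical registration on a DDP-wrapped model:
# state = {"process_group": None, "global_batch_size": global_bs}
# model.register_comm_hook(state=state, hook=global_batch_size_hook)
```

Since the local batch sizes are only known at runtime, you would refresh `state["global_batch_size"]` each iteration before calling `backward()`, while keeping DDP's bucketed, overlapped communication.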