Gradient accumulation for 'DataParallel'

The documentation for `DataParallel` says: "During the backwards pass, gradients from each replica are summed into the original module." I am wondering how this is implemented: how are the gradients from the different devices accumulated? I have read the code but am still confused.


The `replicate` function is part of the computational graph. Broadcasting the parameters to each device is recorded as an autograd operation, and the backward of that broadcast is a reduce that sums the gradients (in PyTorch this is `ReduceAddCoalesced` in `torch.nn.parallel._functions`). So the gradients from the replicas are summed there, back onto the original module's device.
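A minimal sketch of the idea, using a toy custom `torch.autograd.Function` (the name `ToyBroadcast` and the two-replica setup are illustrative, not PyTorch's actual implementation): the forward hands each "replica" its own copy of the tensor, and the backward sums the gradients coming back from the replicas, which is how the copies' gradients end up accumulated into the one original parameter.

```python
import torch
from torch.autograd import Function

class ToyBroadcast(Function):
    """Toy stand-in for the broadcast done by replicate:
    forward copies the tensor to N=2 'replicas',
    backward sums the gradients from those replicas."""

    @staticmethod
    def forward(ctx, x):
        # each replica gets an identical copy of the parameter
        return x.clone(), x.clone()

    @staticmethod
    def backward(ctx, g1, g2):
        # the reduce step: gradients from all replicas are summed
        # into a single gradient for the original tensor
        return g1 + g2

w = torch.tensor(2.0, requires_grad=True)
a, b = ToyBroadcast.apply(w)

# pretend each replica computes its own local loss
loss = 3 * a + 5 * b
loss.backward()

print(w.grad)  # tensor(8.) -- the per-replica gradients 3 and 5, summed
```

In real `DataParallel` the copies live on different GPUs and the sum happens across devices, but the autograd mechanism is the same: because the copy is part of the graph, its backward is where the accumulation happens.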