The documentation for `DataParallel` says "During the backwards pass, gradients from each replica are summed into the original module." I am wondering how this is implemented: how are the gradients accumulated from the different devices? I have read the code but am still confused.
`replicate` is part of the computational graph: the broadcast it performs is differentiable. Its backward is a reduce (sum) function, so during the backward pass the gradients from all replicas are summed there, back onto the original module's parameters.
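To make the idea concrete, here is a minimal conceptual sketch (plain Python, no real devices or autograd; the `Broadcast` class and its methods are illustrative, not PyTorch's actual API): the forward copies a parameter to each replica, and the matching backward sums the per-replica gradients back onto the original.

```python
class Broadcast:
    """Toy model of a differentiable replicate/broadcast op.

    forward: copy the parameter list to each replica.
    backward: reduce (sum) the gradients coming back from every replica,
    which is how "gradients from each replica are summed into the
    original module" during the backward pass.
    """

    @staticmethod
    def forward(param, num_replicas):
        # One independent copy of the parameters per "device".
        return [list(param) for _ in range(num_replicas)]

    @staticmethod
    def backward(replica_grads):
        # Element-wise sum across replicas -> gradient for the original.
        return [sum(grads) for grads in zip(*replica_grads)]


# Usage: 3 replicas each compute their own gradient for 2 parameters.
param = [1.0, 2.0]
replicas = Broadcast.forward(param, 3)
replica_grads = [[1, 2], [3, 4], [5, 6]]
print(Broadcast.backward(replica_grads))  # [9, 12]
```

In real `DataParallel` this reduction happens automatically because the replication step is itself an autograd node, so calling `loss.backward()` flows each replica's gradients through it and they land, summed, on the original module.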