The documentation for ‘DataParallel’ says “During the backwards pass, gradients from each replica are summed into the original module”. I am wondering how this is implemented, i.e. how the gradients from the different devices are accumulated. I have read the code but am still confused.
The function `replicate` is part of the computational graph: it broadcasts the parameters to each device through an autograd function. The backward of that broadcast performs a reduce-add, so the gradients from the replicas are summed there, back onto the original module's parameters.
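A minimal, self-contained sketch of the idea (not the actual PyTorch internals — the function names here are illustrative): the forward of the broadcast copies one parameter to every replica, and its backward sums the per-replica gradients elementwise into a single gradient for the original parameter.

```python
# Conceptual sketch of how replicate()'s backward accumulates gradients.
# In real DataParallel this happens inside an autograd Function during
# the backward pass; here we model it with two plain functions.

def broadcast_forward(param, num_replicas):
    # Forward: hand each replica a copy of the original parameter.
    return [list(param) for _ in range(num_replicas)]

def broadcast_backward(replica_grads):
    # Backward: sum the gradients from all replicas elementwise,
    # producing one gradient for the original parameter.
    total = [0] * len(replica_grads[0])
    for grad in replica_grads:
        for i, g in enumerate(grad):
            total[i] += g
    return total

# Suppose three replicas each computed a gradient on their shard of the batch:
grads = [[1, 2], [3, 4], [5, 6]]
# The backward of the broadcast sums them into the original module:
print(broadcast_backward(grads))  # → [9, 12]
```

This is why the docs say the gradients are “summed into the original module”: the summation is just the gradient of the copy/broadcast operation itself, so autograd performs it automatically.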