Parallelize across multiple processes but only run the optimization step on process 0

I am trying to implement the following parallelization scheme, but I am not sure whether it is possible.
For example, train with multiple processes (CPU cores), with each process handling its own independent batches. Instead of taking the optimization step independently for each batch, I want to gather the loss and gradients from all processes and take the optimization step only on process 0.
Is it possible to do that with the torch.distributed package?

So far I have followed the instructions from https://pytorch.org/tutorials/intermediate/dist_tuto.html and have it working.
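
My setup roughly follows the tutorial along these lines (a minimal sketch with placeholder names, using the gloo backend for CPU processes; the exact values are just examples):

```python
import os
import torch.distributed as dist
import torch.multiprocessing as mp

def run_worker(rank, world_size):
    # Rendezvous settings for the default process group (example values).
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # ... each rank builds its model and trains on its own batches here ...

    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 4  # number of CPU worker processes (example value)
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size)
```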

Hey @stclaireva

Is it possible to do that with the torch.distributed package?

This is possible. You can do the following (a rough code sketch follows the list):

  1. Run forward-backward locally.
  2. Use all_gather or all_gather_coalesced to collect all gradients on rank 0.
  3. Manually add/average those gradients into param.grad on rank 0.
  4. Run optimizer.step() on rank 0 to update parameters.
  5. Let rank 0 broadcast the updated parameters to other ranks.
  6. Go to step 1 to start a new iteration.
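
Here is a minimal sketch of one such iteration, assuming the process group is already initialized and that `model`, `optimizer`, `loss_fn`, and the per-rank `batch` are defined (all names are illustrative, and it assumes every parameter receives a gradient):

```python
import torch
import torch.distributed as dist

def train_step(model, optimizer, batch, loss_fn, rank, world_size):
    # 1. Run forward-backward locally on this rank's batch.
    optimizer.zero_grad()
    inputs, targets = batch
    loss = loss_fn(model(inputs), targets)
    loss.backward()

    # 2. Collect gradients from every rank (all_gather gives each rank a
    #    copy; only rank 0 will actually use them).
    for param in model.parameters():
        grad_list = [torch.zeros_like(param.grad) for _ in range(world_size)]
        dist.all_gather(grad_list, param.grad)

        # 3. Average the gathered gradients into param.grad on rank 0.
        if rank == 0:
            param.grad = torch.stack(grad_list).mean(dim=0)

    # 4. Only rank 0 updates the parameters.
    if rank == 0:
        optimizer.step()

    # 5. Rank 0 broadcasts the updated parameters to the other ranks.
    for param in model.parameters():
        dist.broadcast(param.data, src=0)

    # 6. The caller loops back to step 1 for the next iteration.
    return loss.item()
```

Note that all_gather gives every rank a copy of all the gradients; if you only ever need them on rank 0, dist.gather(param.grad, gather_list, dst=0) would also work.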

Out of curiosity, why do you need the above algorithm instead of letting DistributedDataParallel handle this for you?
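
For comparison, a minimal DistributedDataParallel version (assuming the same process-group setup as above; the function and variable names are illustrative) averages gradients across ranks inside backward(), so every rank takes an identical optimizer step without any manual gather or broadcast:

```python
from torch.nn.parallel import DistributedDataParallel as DDP

def train_with_ddp(model, optimizer, data_loader, loss_fn):
    # Wrapping the model is enough; for CPU/gloo no device_ids are needed.
    ddp_model = DDP(model)
    for inputs, targets in data_loader:
        optimizer.zero_grad()
        loss = loss_fn(ddp_model(inputs), targets)
        loss.backward()   # DDP all-reduces (averages) gradients here
        optimizer.step()  # identical update on every rank keeps params in sync
```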