Gradient Aggregation in PyTorch

A simple MLP model is considered, and the initial model() is shared between the workers.
Conventionally, each worker trains the model on its own dataset and shares the resulting model. Averaging the models from the different workers gives the new model for the next round of the algorithm.
How can I efficiently share the gradients of the different workers, average them, and use “optimizer.step()” to update the initial model (the initial model shared among the workers)?
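For reference, the model-averaging round described in the question can be sketched in a single process. This is a minimal sketch, assuming a hypothetical `make_mlp()` as a stand-in for the MLP and random tensors as each worker's private data; `average_models` is also a hypothetical helper, not a PyTorch API. Note that with exactly one local SGD step per worker, averaging the models is equivalent to applying the averaged gradient to the shared model, because each worker ends at `w0 - lr * g_i` and their mean is `w0 - lr * mean(g_i)`.

```python
import copy
import torch
import torch.nn as nn

def make_mlp() -> nn.Module:
    # Hypothetical MLP standing in for the model in the question
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

def average_models(models):
    """Return a copy of models[0] whose tensors are the element-wise mean over all models."""
    avg = copy.deepcopy(models[0])
    avg_state = avg.state_dict()
    for key in avg_state:
        # Stack the same entry from every worker and average it
        avg_state[key] = torch.stack([m.state_dict()[key].float() for m in models]).mean(dim=0)
    avg.load_state_dict(avg_state)
    return avg

torch.manual_seed(0)
global_model = make_mlp()
workers = [copy.deepcopy(global_model) for _ in range(3)]  # every worker starts from the shared model
for w in workers:
    # One local SGD step on random data standing in for each worker's private dataset
    opt = torch.optim.SGD(w.parameters(), lr=0.1)
    x, y = torch.randn(16, 4), torch.randn(16, 1)
    opt.zero_grad()
    nn.functional.mse_loss(w(x), y).backward()
    opt.step()
global_model = average_models(workers)  # new model for the next round
```

In a real multi-machine setup the workers would run in separate processes, and the averaging would be done with a collective such as an all-reduce rather than by stacking local copies.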

Hey @Ohm, if you are using DistributedDataParallel (DDP), you can try the no_sync context manager. For example, you can wrap the local training iterations in no_sync. When you need to average gradients, just run one forward-backward pass outside the no_sync context, and DDP should take care of the gradient synchronization.
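A minimal sketch of that pattern, assuming CPU training with the gloo backend and a `file://` rendezvous so no free port is needed; `train_round`, the toy MLP, and the random batches are hypothetical. Gradients from the local iterations accumulate inside `no_sync()` without communication, and the final forward-backward outside the context triggers DDP's all-reduce, which averages the accumulated gradients across workers before a single `optimizer.step()`. This is gradient accumulation with one synchronized step per round, not per-worker optimizer steps.

```python
import os
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def train_round(rank, world_size, init_file, local_steps=3):
    # One process per worker; gloo runs on CPU and file:// avoids needing a free port
    dist.init_process_group("gloo", init_method=f"file://{init_file}",
                            rank=rank, world_size=world_size)
    torch.manual_seed(rank)  # different local data per worker
    # DDP broadcasts rank 0's parameters on construction, so all workers start from the same model
    model = DDP(nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1)))
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    batches = [(torch.randn(16, 4), torch.randn(16, 1)) for _ in range(local_steps)]

    opt.zero_grad()
    with model.no_sync():
        # Local iterations: gradients accumulate in .grad with no communication
        for x, y in batches[:-1]:
            nn.functional.mse_loss(model(x), y).backward()
    # One forward-backward outside no_sync: DDP all-reduces and averages the accumulated gradients
    x, y = batches[-1]
    nn.functional.mse_loss(model(x), y).backward()
    opt.step()  # every worker applies the same averaged gradient
    dist.destroy_process_group()
    return model.module
```

To run N workers, one option is `torch.multiprocessing.spawn(train_round, args=(N, init_file), nprocs=N)`, where `init_file` is a path to a not-yet-existing file in a fresh temporary directory; `spawn` passes the process index as the `rank` argument.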

Another option would be building your application using torch.distributed.rpc and then use a parameter server to sync models. See this tutorial.
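The tutorial itself is not reproduced here. As a hedged illustration of the parameter-server idea only, the sketch below simulates it in a single process: a hypothetical `ToyParameterServer` holds the global model and optimizer, workers pull a copy, compute gradients on their own data, and push them back, and the server averages them and calls `optimizer.step()`. In a real deployment the pull and push would be remote calls made through torch.distributed.rpc, which this sketch deliberately leaves out.

```python
import copy
import torch
import torch.nn as nn

class ToyParameterServer:
    """Hypothetical single-process stand-in for a parameter server (no RPC involved)."""

    def __init__(self, model: nn.Module, lr: float = 0.1):
        self.model = model
        self.opt = torch.optim.SGD(self.model.parameters(), lr=lr)

    def get_model(self) -> nn.Module:
        # Workers pull a copy of the current global model
        return copy.deepcopy(self.model)

    def apply_gradients(self, worker_grads):
        # worker_grads: one list of gradient tensors per worker, in parameter order
        for p, grads in zip(self.model.parameters(), zip(*worker_grads)):
            p.grad = torch.stack(grads).mean(dim=0)  # average across workers
        self.opt.step()
        self.opt.zero_grad()

def compute_local_grads(model, x, y):
    # Hypothetical worker step: backward on local data, return detached gradient copies
    model.zero_grad()
    nn.functional.mse_loss(model(x), y).backward()
    return [p.grad.detach().clone() for p in model.parameters()]

torch.manual_seed(0)
ps = ToyParameterServer(nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1)))
for _ in range(5):  # training rounds
    grads = [compute_local_grads(ps.get_model(), torch.randn(16, 4), torch.randn(16, 1))
             for _ in range(3)]  # three simulated workers
    ps.apply_gradients(grads)
```

Keeping the optimizer on the server means only gradients travel from workers to the server, and workers never need their own optimizer state.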

If all parameters are dense, the DDP solution should be more efficient. If there are sparse parameters, the parameter server solution might be faster.