Update model's parameter with moving average in DDP

Consider a model trained with DDP in which all parameters are updated by gradients, except for one that only tracks the moving average of a value generated after each forward pass during training.

Is there an easy way to implement this in DDP?



Do you mean a separate parameter that is not part of model.parameters() and that needs to track a moving average across all workers?

If so, you should be able to just use all_reduce to get the average of this value across all workers, and then apply any standard running-mean update to the tracked parameter on the desired worker.