Hi,
I am training a model where the learning rate is automatically updated (decay on plateau).
This works fine on a single GPU.
I am porting my code to work with DDP.
After every epoch, I run a validation loop (only if local_rank == 0) and then decide whether or not to update the learning rate.
Pseudocode:
    if local_rank == 0:
        new_lr = run_validation_and_get_lr()
        for param_group in optim.param_groups:
            param_group["lr"] = new_lr
However, this seems to only update the learning rate in process 0, and not in the other processes.
How can I ensure the LR is updated in all processes?
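
One possible approach, as a minimal sketch rather than a confirmed solution: let the rank that ran validation broadcast the chosen value with torch.distributed.broadcast, and have every rank write it into its own optimizer. The helper name broadcast_lr and its parameters are hypothetical. The sketch assumes the process group is already initialized and that every rank calls the helper (broadcast is a collective, so a rank that skips it will hang the others). Note that src is a global rank, while local_rank == 0 is true once per node in a multi-node job.

```python
import torch
import torch.distributed as dist


def broadcast_lr(optimizer, new_lr=None, src=0, device="cpu"):
    """Copy the learning rate chosen on rank `src` into every rank's optimizer.

    Hypothetical helper: must be called by ALL ranks after validation.
    Ranks other than `src` can pass new_lr=None; their value is overwritten.
    """
    # Only the source rank holds a meaningful value before the broadcast.
    value = float(new_lr) if dist.get_rank() == src else 0.0
    # With the NCCL backend the tensor has to live on this rank's GPU,
    # e.g. device=f"cuda:{local_rank}"; "cpu" works with Gloo.
    lr_tensor = torch.tensor([value], dtype=torch.float64, device=device)
    dist.broadcast(lr_tensor, src=src)
    lr = lr_tensor.item()
    # Apply the synchronized value locally on every rank.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    return lr
```

With the names from the pseudocode above, usage might look like this: compute new_lr = run_validation_and_get_lr() only on rank 0 (None elsewhere), then call broadcast_lr(optim, new_lr, src=0, device=f"cuda:{local_rank}") outside the if, so that all processes reach it.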