Valid loss dependent LR scheduling and DDP


I am training a model where the learning rate is automatically updated (decay on plateau).
This works fine on single GPU.

I am porting my code to work with DDP.
After every epoch, I run a validation loop (only if local_rank == 0) and then decide whether or not to update the learning rate.


if local_rank == 0:
    new_lr = run_validation_and_get_lr()
    for param_group in optim.param_groups:
        param_group["lr"] = new_lr

However, this seems to only update the learning rate in process 0, and not in the other processes.

How can I ensure the LR is updated in all processes?


I'm guessing this is because your update is conditioned on if local_rank == 0, so it only ever runs on rank 0. If you only want to run the validation on rank 0, you can broadcast the result to the rest of the ranks as follows:

if local_rank == 0:
    # Rank 0 computes the new LR and puts it in a tensor to broadcast.
    lr_tensor = torch.tensor([compute_new_lr()])
else:
    # Other ranks allocate a same-shaped tensor to receive the value.
    lr_tensor = torch.zeros(1)
# broadcast is a collective, so every rank must call it (outside the if).
dist.broadcast(lr_tensor, src=0)
# All ranks apply the synchronized LR.
for param_group in optim.param_groups:
    param_group["lr"] = lr_tensor.item()