ReduceLROnPlateau usage with DistributedDataParallel?

I am implementing a model and my model + data does not fit on a single GPU. I am using DistributedDataParallel because the documentation recommends it over DataParallel.

My model appears to work now (it can overfit), but I am unsure about how I should use ReduceLROnPlateau.

Is it safe to simply call scheduler.step(validation_accuracy) regardless of rank? Or should I only call it on rank 0 and broadcast the resulting learning rate to the other processes (and how)?

I’ve currently implemented something like this; does this look correct?

dist.barrier()  # Make sure every process has finished validation for this epoch.
# With the NCCL backend, the tensor must live on this process's GPU (device).
acc_tensor = torch.tensor(val_acc, device=device)
dist.all_reduce(acc_tensor, op=dist.ReduceOp.SUM)
scheduler.step(acc_tensor.item() / world_size)  # step with the mean accuracy

In my case, the world size equals both the number of processes and the number of GPUs, so this averages the accuracy across all processes and uses the mean to update the scheduler.

I think this is correct, but I am unsure whether the barrier is needed. I assume it is, because otherwise some processes may be lagging behind and may not have updated val_acc yet.

Hello

You’d need to initialize the process group for DDP, and it will synchronize the workers for the first iteration.

I'm not sure why you are calling dist.all_reduce() here; normally DDP handles the synchronization of model parameters (via their gradients) for you.

Is your scheduler similar to an optimizer in this context?

I think this is correct, but I am unsure whether the barrier is needed. I assume it is, because otherwise some processes may be lagging behind and may not have updated val_acc yet.

You generally don’t need to synchronize your workers manually, see e.g.
https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
(barriers are typically only used to synchronize around checkpointing)
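As an illustration of that checkpointing pattern, here is a minimal sketch, assuming a process group is already initialized; the function name and arguments are my own, not from the tutorial:

```python
import torch
import torch.distributed as dist

def save_checkpoint(model, path, rank):
    # Only rank 0 writes the checkpoint; under DDP the replicas are
    # identical, so one copy is enough.
    if rank == 0:
        torch.save(model.state_dict(), path)
    # Barrier so no rank reads (or proceeds past) the file before
    # rank 0 has finished writing it.
    dist.barrier()
```

Outside of situations like this, where one rank performs a side effect the others depend on, DDP's own gradient synchronization already keeps the workers in step.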

The scheduler is ReduceLROnPlateau; it updates the learning rate based on a metric (in my case, validation accuracy).

Because val_acc is not a model parameter, I would assume it differs on every process (each process has its own mini-batch). Therefore, I need to synchronise it so that every process changes the learning rate at the same time.

When I find some time, I will try and work out a minimum example.

Sorry, I think I misunderstood your question initially.

(thanks @rvarm1 for suggestion!)

Do you do validation on rank 0? If so, you can compute val_acc on rank 0 and broadcast it to all ranks, at which point each rank can call scheduler.step() independently, so all ranks stay consistent. Note that the broadcast operation will also synchronize the ranks for you.
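A minimal sketch of that broadcast pattern, assuming the process group is already initialized and `device` is this rank's device (the function name is mine, not a PyTorch API):

```python
import torch
import torch.distributed as dist

def step_scheduler_with_broadcast(scheduler, val_acc, device):
    # Rank 0 passes the real metric; other ranks can pass a placeholder
    # (e.g. 0.0) since their value is overwritten by the broadcast.
    acc = torch.tensor([float(val_acc)], device=device)
    # broadcast() blocks until all ranks participate, so it doubles as a
    # synchronization point; no extra dist.barrier() is needed.
    dist.broadcast(acc, src=0)
    # Every rank now steps with rank 0's metric, so the learning rates
    # stay consistent across processes.
    scheduler.step(acc.item())
    return acc.item()
```

Since ReduceLROnPlateau's state is driven entirely by the metric it sees, feeding every rank the same value keeps the schedulers in lockstep without any explicit learning-rate broadcast.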

(by @wayi)
maybe you can also consider using something like https://arxiv.org/pdf/2007.05105.pdf
or fairscale (https://github.com/facebookresearch/fairscale), PyTorch extensions for high performance and large scale training.