I am implementing a model and my model + data does not fit on a single GPU. I am using DistributedDataParallel because the documentation recommends it over DataParallel.
My model appears to work now (it can overfit), but I am unsure about how I should use ReduceLROnPlateau.
Is it safe to simply call scheduler.step(validation_accuracy) regardless of rank? Or should I only call it on rank 0 and broadcast the resulting learning rate to the other processes (and if so, how)?
I’ve currently implemented something like this; does it look correct?
import torch
import torch.distributed as dist

dist.barrier()  # Synchronize, making sure all processes have reached the end of this epoch.
acc_tensor = torch.tensor(val_acc).cuda()  # NCCL requires CUDA tensors
dist.all_reduce(acc_tensor)  # default op is SUM: sums val_acc over all processes
scheduler.step(acc_tensor.item() / world_size)  # every rank steps with the average accuracy
In my case, the world size is the number of processes (one per GPU), so this averages the accuracy across all processes and uses the average to update the scheduler.
I think this is correct, but I am unsure if the barrier is needed. I assume so, because otherwise some processes may be lagging behind and not have updated val_acc yet.
You’d need to initialize the process group for DDP, and it will synchronize the workers for the first iteration.
I’m not sure why you are trying to call dist.all_reduce() here; normally DDP will handle the model parameters for you.
Is your scheduler similar to an optimizer in this context?
> I think this is correct, but I am unsure if the barrier is needed. I assume so, because otherwise some processes may be lagging behind and not have updated val_acc yet.
You generally don’t need to synchronize your workers; barriers are typically only used to synchronize while checkpointing.
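For illustration, a common pattern where a barrier does appear is checkpointing; a minimal sketch ("checkpoint.pt" is a placeholder path):

import torch
import torch.distributed as dist

if dist.get_rank() == 0:
    # With DDP you would typically save model.module.state_dict()
    torch.save(model.state_dict(), "checkpoint.pt")  # only rank 0 writes the file
dist.barrier()  # the other ranks wait here until rank 0 has finished writing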
The scheduler is ReduceLROnPlateau; it is used to update the learning rate based on a metric (in my case, validation accuracy).
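For reference, a minimal sketch of how such a scheduler is typically set up (the optimizer and its settings are placeholders; mode="max" because a higher accuracy is better):

import torch

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", patience=5)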
val_acc is not a model parameter, and I would assume it to be different on every process (because every process has its own mini-batch). Therefore, I need to synchronize it so every process changes the learning rate at the same time.
When I find some time, I will try to work out a minimal example.
Sorry, I think I misunderstood your question initially.
(thanks @rvarm1 for the suggestion!)
Do you do validation on rank 0? If so, you can compute val_acc on rank 0 and broadcast it to all ranks, at which point you can run scheduler.step independently, so all ranks are consistent. Note that the broadcast operation will synchronize the ranks for you.
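A minimal sketch of that approach, assuming an NCCL backend and a hypothetical evaluate() helper that computes the validation accuracy:

import torch
import torch.distributed as dist

if dist.get_rank() == 0:
    val_acc = evaluate(model, val_loader)  # hypothetical validation helper
else:
    val_acc = 0.0  # placeholder; overwritten by the broadcast below
acc_tensor = torch.tensor(val_acc).cuda()  # NCCL requires CUDA tensors
dist.broadcast(acc_tensor, src=0)  # copies rank 0's value to all ranks and synchronizes them
scheduler.step(acc_tensor.item())  # every rank steps with the same metric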
Maybe you can also consider using something like https://arxiv.org/pdf/2007.05105.pdf
GitHub - facebookresearch/fairscale: PyTorch extensions for high performance and large scale training.