How to use LR scheduler in DDP and amp


I am wondering whether there are any tutorials or examples showing the correct usage of a learning rate scheduler when training with DDP/FSDP. For example, if the LR scheduler is OneCycleLR, how should I define the total number of steps in the cycle, i.e., the total_steps or (steps_per_epoch and epochs) arguments of the scheduler?
The reason I am asking is that this scheduler updates the LR based on step count, but in DDP the number of steps taken by each process differs from the effective number of steps (e.g., in non-DDP training there are 256 steps, but in DDP with 2 GPUs each process takes only 128 steps).
Also, where should I call scheduler.step() if I am also using amp? Assume my code is:

steps_per_epoch = ???
n_epoch = 10
scheduler = OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=steps_per_epoch, epochs=n_epoch)
for it, (img, labels) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        bs = img.size(0)

        img = img.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)

        # compute output and loss
        with autocast(enabled=args.amp, dtype=args.amp_dtype):
            outputs = model(img)
            loss = criteria(outputs, labels)

        # scheduler.step() is this position correct?

Thank you!

Yes, with DDP the effective batch size scales by N, so the per-process steps_per_epoch is effectively the single-GPU value divided by N, and that is likely what you want to use.
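Concretely, if your train_loader is built on a DistributedSampler, len(train_loader) already gives the per-process step count. The arithmetic itself is just a ceiling division; a framework-free sketch (the helper name per_process_steps is mine, not from the thread):

```python
from math import ceil

def per_process_steps(dataset_size: int, batch_size: int, world_size: int) -> int:
    """Steps each rank takes per epoch when a DistributedSampler
    splits the dataset evenly across world_size processes."""
    return ceil(dataset_size / (world_size * batch_size))

# 8192 samples, per-GPU batch of 32:
# single GPU -> 256 steps; 2 GPUs -> 128 steps each
print(per_process_steps(8192, 32, 1))  # 256
print(per_process_steps(8192, 32, 2))  # 128
```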

I’m assuming you’re scaling the LR proportionally to N as well? A rule of thumb I commonly see is lr*sqrt(N), but afaik the best scaling is an open question.
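For illustration, the two rules most often quoted side by side (the helper is mine, purely a sketch; as said above, which rule works better is an open question):

```python
def scaled_lr(base_lr: float, world_size: int, rule: str = "sqrt") -> float:
    # sqrt scaling roughly preserves the gradient noise level;
    # linear scaling matches the larger effective batch directly.
    if rule == "sqrt":
        return base_lr * world_size ** 0.5
    elif rule == "linear":
        return base_lr * world_size
    raise ValueError(f"unknown rule: {rule}")

print(scaled_lr(0.01, 4, "sqrt"))    # 0.02
print(scaled_lr(0.01, 4, "linear"))  # 0.04
```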

Thank you for the reply!

Are you suggesting that steps_per_epoch should be set to 128 in the example I gave for 2 GPUs? In that case, since each process ends up with gradients averaged over the full effective batch (i.e., 256), wouldn't setting steps_per_epoch = 128 be wrong when the scheduler updates the LR?

Yeah, in your example if a single GPU setup takes 256 steps to run through the entire dataset, then steps_per_epoch should be 128 for 2 GPUs.

At steps_per_epoch=128 your distributed model will have seen the entire dataset, which is equivalent to setting steps_per_epoch=256 in a non-distributed setting.
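On the amp part of the original question: the ordering shown in PyTorch's amp examples is backward on the scaled loss, then scaler.step(optimizer), then scaler.update(), and for a per-batch scheduler like OneCycleLR, scheduler.step() after the optimizer step, not inside the autocast block. A minimal runnable sketch of that ordering (amp disabled here so it runs on CPU; the tiny model and random data are placeholders for your own):

```python
import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import OneCycleLR

model = nn.Linear(8, 2)                      # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
steps_per_epoch, n_epoch = 4, 2              # per-process step count under DDP
scheduler = OneCycleLR(optimizer, max_lr=0.01,
                       steps_per_epoch=steps_per_epoch, epochs=n_epoch)
criterion = nn.CrossEntropyLoss()
scaler = GradScaler(enabled=False)           # enabled=args.amp in real code

for epoch in range(n_epoch):
    for step in range(steps_per_epoch):
        img = torch.randn(16, 8)             # stand-in for a train_loader batch
        labels = torch.randint(0, 2, (16,))
        optimizer.zero_grad(set_to_none=True)
        with autocast(enabled=False):        # enabled=args.amp, dtype=args.amp_dtype
            outputs = model(img)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()        # backward outside the autocast block
        scaler.step(optimizer)               # skips the update on inf/nan grads under amp
        scaler.update()
        scheduler.step()                     # once per optimizer step, right after it
```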