How to scale/warmup the learning rate for large batch size?

I am trying to run ImageNet training on a large number of GPUs (>64) with PyTorch DDP and a batch size of 64 per GPU. I am unsure how to scale and warm up the learning rate:

  • the original PyTorch DDP ImageNet example does not scale the learning rate at all and only decays it by a factor of 10 every 30 epochs
  • the DALI dataloader with PyTorch DDP implementation scales the learning rate with the number of workers (relative to a base batch size of 256) and also uses 5 epochs of warm-up
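The scaling in the second bullet is the linear scaling rule: the learning rate grows in proportion to the global batch size, measured against a reference batch of 256. Below is a minimal sketch of that arithmetic. The function name `scaled_lr` is hypothetical, and the base LR of 0.1 is only an assumed reference value for batch 256.

```python
def scaled_lr(base_lr, per_gpu_batch, world_size, base_batch=256):
    """Linear scaling rule (sketch): multiply base_lr by global_batch / base_batch."""
    # Global batch = per-GPU batch times the number of DDP processes
    global_batch = per_gpu_batch * world_size
    return base_lr * global_batch / base_batch


if __name__ == "__main__":
    # Assumed base LR of 0.1 at batch 256; 64 images per GPU
    for gpus in (4, 64, 512):
        print(gpus, "GPUs ->", scaled_lr(0.1, 64, gpus))
```

With 64 images per GPU, 512 GPUs give a global batch of 32,768, so the scaled LR ends up 128 times the base LR. A jump that large is why the warm-up phase matters so much at this scale.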

However, in my case both setups fail to reach 70% validation accuracy when trained with a global batch size larger than 4096. For comparison, Horovod reaches ~74% validation accuracy out of the box for global batch sizes up to 32k, using exactly the same LR schedule as the DALI example. How should I adjust the LR in PyTorch to make this work?

@caesar025 thanks for posting!

There are some previous discussions about how to adjust the learning rate when scaling up the batch size. Have you tried those suggestions already? Should we split batch_size according to ngpu_per_node when DistributedDataparallel - #19 by junb

I was already scaling the learning rate with the number of workers, so that was not the issue. My mistake was in the learning-rate warm-up. As it turns out, the correct way to do it is:

    # Linear warm-up, updated every iteration: ramps lr from ~0 to its scaled value
    if epoch < args.warmup_epochs:
        lr = lr * float(1 + step + epoch * len_epoch) / (args.warmup_epochs * len_epoch)

where len_epoch = len(train_loader) and step is the iteration index within the current epoch. With this fix I get ~74% validation accuracy with a batch size of 32k, so everything is good now!
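Here is the warm-up above wrapped into a complete per-iteration schedule. This is a minimal sketch, not the exact DALI code. The function `lr_at` and its parameters `decay_every` and `gamma` are hypothetical. The step decay (divide by 10 every 30 epochs) is an assumption modeled on the ImageNet example's schedule. `peak_lr` is the already scaled learning rate.

```python
def lr_at(peak_lr, epoch, step, len_epoch, warmup_epochs=5, decay_every=30, gamma=0.1):
    """Per-iteration LR (sketch): assumed step decay plus linear warm-up."""
    # Assumed step decay: multiply by gamma every `decay_every` epochs
    lr = peak_lr * (gamma ** (epoch // decay_every))
    # Linear warm-up, evaluated per iteration rather than per epoch:
    # the global iteration count grows by 1 each step, reaching
    # warmup_epochs * len_epoch on the last warm-up iteration
    if epoch < warmup_epochs:
        lr = lr * float(1 + step + epoch * len_epoch) / (warmup_epochs * len_epoch)
    return lr


if __name__ == "__main__":
    len_epoch = 1000  # hypothetical number of iterations per epoch
    for epoch, step in [(0, 0), (2, 500), (4, 999), (5, 0), (30, 0)]:
        print(epoch, step, lr_at(12.8, epoch, step, len_epoch))
```

In a PyTorch training loop, you would call this once per iteration and write the returned value into every entry of `optimizer.param_groups` (`param_group['lr'] = lr`) before `optimizer.step()`. Because the ramp is computed from the global iteration count, the LR rises smoothly within each epoch instead of jumping at epoch boundaries.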
