Is the following a correct way to implement linear warmup?

    for epoch in range(args.start_epoch, args.epochs + args.warmup_epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # linear warmup: ramp the lr from warmup_lr up to lr over the first
        # warmup_epochs epochs, before the scheduler takes over
        if args.warmup_epochs and epoch <= args.warmup_epochs:
            lr = args.warmup_lr + (args.lr - args.warmup_lr) * (epoch / args.warmup_epochs)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, scaler, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch - args.warmup_epochs,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best)

        # after warmup, the scheduler controls the lr
        if epoch >= args.warmup_epochs:
            scheduler.step()
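
To sanity-check the warmup ramp in isolation, here is a minimal standalone sketch (dummy model, made-up hyperparameters standing in for args.*, no real training) that prints the lr each epoch:

    import torch

    warmup_lr, base_lr, warmup_epochs, epochs = 0.01, 0.1, 5, 10  # made-up values

    model = torch.nn.Linear(4, 2)  # dummy model
    optimizer = torch.optim.SGD(model.parameters(), lr=base_lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

    for epoch in range(epochs + warmup_epochs):
        if warmup_epochs and epoch <= warmup_epochs:
            lr = warmup_lr + (base_lr - warmup_lr) * (epoch / warmup_epochs)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        print(epoch, optimizer.param_groups[0]['lr'])
        optimizer.step()  # no-op here (no grads); keeps PyTorch from warning about step order
        if epoch >= warmup_epochs:
            scheduler.step()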

Most of the implementations I found couple warmup with the scheduler, roughly like the built-in pattern shown below. I am trying to implement one that decouples them. Is the above implementation correct, or is there a better way to implement it?
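
For reference, the coupled pattern would look something like this, using PyTorch's built-in LinearLR and SequentialLR (the cosine main schedule and the concrete numbers are just assumptions for illustration):

    import torch

    model = torch.nn.Linear(4, 2)  # dummy model for illustration
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # warmup lives inside the scheduler: ramp from 0.1*lr to lr over 5 epochs,
    # then hand off to cosine annealing at the milestone
    warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=5)
    main = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=95)
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup, main], milestones=[5])

    # the training loop then just calls scheduler.step() once per epoch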

Looks good, but perhaps you'd also need to save scheduler.state_dict() to resume training correctly (though constructing the scheduler with last_epoch=epoch should be enough for most schedulers, I think).
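
A sketch of what I mean, extending your checkpoint dict plus the matching resume path (args.resume and the surrounding names are assumed from your script):

    # when saving, include the scheduler state next to the optimizer state
    save_checkpoint({
        'epoch': epoch - args.warmup_epochs,
        'arch': args.arch,
        'state_dict': model.state_dict(),
        'best_acc1': best_acc1,
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),  # added
    }, is_best)

    # when resuming, restore it before entering the epoch loop
    checkpoint = torch.load(args.resume, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])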