Batch-frequency optim scheduler and step order for torch>=1.1.0

Hi, I’ve set up my code to use either epoch-frequency schedulers like MultiStepLR or batch-frequency schedulers like OneCycleLR. In both cases optimizer.step() happens before scheduler.step(). However, when I use a batch-frequency scheduler, I get the warning advising me that the ordering is wrong:

UserWarning: Detected call of lr_scheduler.step() before optimizer.step(). In PyTorch 1.1.0 and later, you should call them in the opposite order: optimizer.step() before lr_scheduler.step(). Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at torch.optim — PyTorch 2.1 documentation

Is there a chance this is a false warning in the batch-scheduler case, or is there still likely to be something incorrect with my code?

There might always be a bug we haven’t found yet, but I haven’t seen this warning in the latest release, so I guess your code might accidentally be calling the scheduler and optimizer in the wrong order.
Could you post a minimal, executable code snippet which shows this warning?
Also, are you using automatic mixed-precision training? If so, note that the optimizer.step() call can be skipped if invalid (Inf/NaN) gradients are found and the GradScaler needs to reduce the scaling factor.
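A minimal CPU-only sketch of why a skipped optimizer.step() produces exactly this warning. Here optimizer.step() is simply never called, standing in (as an assumption for illustration) for a first iteration in which GradScaler skipped it; the toy `Linear` model and `StepLR` scheduler are placeholders, not the poster's setup.

```python
import warnings

import torch

# Toy model, optimizer and scheduler (placeholders for illustration only)
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # optimizer.step() is deliberately NOT called, mimicking an iteration where
    # GradScaler skipped it because the gradients contained Inf/NaN values
    scheduler.step()

messages = [str(w.message) for w in caught]
print(messages)
```

The scheduler only checks whether the optimizer has stepped at least once before its own first step, so a single skipped step at the very beginning of training (typical with AMP's large initial scale) is enough to trigger the message.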


Hi, thanks for your reply. I am indeed using AMP (together with SWA). Here’s the skeleton of what I’m doing:

```python
import torch
import torch.nn as nn

# create_model, args, dataloader and loss_fn are defined elsewhere
model = create_model(args)
model = nn.DataParallel(model)
swa_model = torch.optim.swa_utils.AveragedModel(model)
scaler = torch.cuda.amp.GradScaler()
optimizer = torch.optim.AdamW(model.parameters(), lr=7e-3)
# OneCycleLR expects at most epochs * steps_per_epoch calls to step()
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                max_lr=7e-2,
                                epochs=5, steps_per_epoch=len(dataloader),
                                pct_start=0.3)
swa_scheduler = torch.optim.swa_utils.SWALR(optimizer,
                                            swa_lr=0.5,
                                            anneal_epochs=5)
num_epochs = 20
swa_start = 15

for epoch in range(1, num_epochs + 1):
    for imgs, labels in dataloader:

        with torch.cuda.amp.autocast():
            output = model(imgs)
            loss = loss_fn(output, labels)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)  # may skip optimizer.step() if grads contain Inf/NaN
        scaler.update()

        if epoch < swa_start:
            scheduler.step()  # batch-frequency scheduler

    if epoch >= swa_start:
        swa_model.update_parameters(model)
        swa_scheduler.step()  # epoch-frequency scheduler
```

In that case AMP can skip optimizer.step() if invalid gradients are found, and you should then skip scheduler.step() as well.
As a workaround, you could check the scale value via scaler.get_scale() and skip scheduler.step() if it was decreased. We’ll provide a better API for it, so that you can directly check whether optimizer.step() was skipped.
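A minimal sketch of that workaround, assuming a toy `Linear` model and random data in place of the real model and dataloader. The scale is read right before scaler.step() and again after scaler.update(), since update() is where the scale actually changes. On a machine without a GPU the scaler and autocast are created with `enabled=False`, so the scaler is a pass-through and the scale stays at 1.0.

```python
import torch
import torch.nn as nn

# Assumption: toy model and random data stand in for the real model/dataloader
device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # only enable AMP when a GPU is available

torch.manual_seed(0)
model = nn.Linear(10, 2).to(device)
loss_fn = nn.CrossEntropyLoss()
dataloader = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(4)]

num_epochs = 3
optimizer = torch.optim.AdamW(model.parameters(), lr=7e-3)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=7e-2, epochs=num_epochs, steps_per_epoch=len(dataloader))
# With enabled=False the scaler is a pass-through and get_scale() returns 1.0
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

scheduler_steps = 0
skipped_steps = 0
for epoch in range(num_epochs):
    for imgs, labels in dataloader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        with torch.autocast(device_type=device, enabled=use_amp):
            loss = loss_fn(model(imgs), labels)
        scaler.scale(loss).backward()

        scale_before = scaler.get_scale()
        scaler.step(optimizer)  # skips optimizer.step() internally on Inf/NaN grads
        scaler.update()         # reduces the scale after a skipped step
        scale_after = scaler.get_scale()

        if scale_after < scale_before:
            # Scale was reduced, so optimizer.step() was skipped: skip the scheduler too
            skipped_steps += 1
        else:
            scheduler.step()
            scheduler_steps += 1

print(f"scheduler steps: {scheduler_steps}, skipped: {skipped_steps}")
```

Comparing with `<` rather than `!=` matters because the scale can also grow after a long run of clean steps, and a grown scale still means the optimizer step was taken.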


Thanks, something like this?

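```python
for epoch in range(epochs):
    for imgs, labels in dataloader:
        start_scale = scaler.get_scale()
        with torch.cuda.amp.autocast():
            loss = loss_fn(model(imgs), labels)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        end_scale = scaler.get_scale()
        if not end_scale < start_scale:
            scheduler.step()
```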

I’m not quite sure if I follow where to place the checks. If I put the start_scale check at the beginning of each batch loop and the end_scale check at the very end, for nearly every batch I get:
start_scale == 65536.0
end_scale == 32.0
So it would never reach scheduler.step().

Edit: Actually, by the start of epoch 2 it seems to settle, with both start_scale and end_scale equal to 32.0.

Should I similarly suppress the swa_scheduler step or is that left alone?

The scale factor should balance itself after a while and not decrease in every iteration. A value of 32.0 sounds reasonable. The end_scale should be read after the scaler.step() and scaler.update() calls, since update() is where the scale is changed.
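To see why the scale drops and then settles, here is a simplified pure-Python model of the update rule, assuming GradScaler's default settings (init_scale=2**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000). `simulate_scale` is a hypothetical helper, not part of PyTorch: every step with Inf/NaN gradients halves the scale (and is a step where optimizer.step() was skipped), while the scale only doubles after 2000 consecutive clean steps.

```python
def simulate_scale(overflows, init_scale=2.0**16, growth_factor=2.0,
                   backoff_factor=0.5, growth_interval=2000):
    """Simplified model of the GradScaler.update() rule (hypothetical helper).

    overflows: iterable of bools, True if that step's gradients had Inf/NaN,
    i.e. a step in which optimizer.step() was skipped.
    Returns the scale after each step.
    """
    scale = init_scale
    growth_tracker = 0  # consecutive steps without Inf/NaN gradients
    history = []
    for overflowed in overflows:
        if overflowed:
            scale *= backoff_factor  # back off and restart the clean-step count
            growth_tracker = 0
        else:
            growth_tracker += 1
            if growth_tracker == growth_interval:
                scale *= growth_factor  # grow only after a long clean streak
                growth_tracker = 0
        history.append(scale)
    return history


# 11 overflowing steps take 65536.0 down to 32.0 (65536 / 2**11 == 32),
# after which 100 clean steps leave the scale unchanged
hist = simulate_scale([True] * 11 + [False] * 100)
print(hist[10], hist[-1])
```

Under these assumptions, a drop from 65536.0 to 32.0 corresponds to 11 skipped optimizer steps, and once the scale has settled the `end_scale < start_scale` check should rarely fire.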

What exactly is swa_scheduler doing? If it’s (re-)creating the “averaged” model using the updated parameters, I think you should also skip it, since the model wasn’t updated.
