Cyclic learning rate - how to use it

I am using torch.optim.lr_scheduler.CyclicLR as shown below:

optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
optimizer.zero_grad()
scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-3, max_lr=1e-2, step_size_up=2000)
for epoch in range(epochs):
    for batch in train_loader:
        X_train = batch['image'].cuda()
        y_train = batch['label'].cuda()
        y_pred = model(X_train)
        loss = loss_fn(y_pred, y_train)
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()    # backprop on the scaled loss
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

Since CyclicLR updates the learning rate after every batch, I have called scheduler.step() inside the batch loop. But when I do this, I get the following warning:

/opt/conda/lib/python3.6/site-packages/torch/optim/lr_scheduler.py:73: UserWarning: Seems like `optimizer.step()` has been overridden after learning rate scheduler initialization. Please, make sure to call `optimizer.step()` before `lr_scheduler.step()`. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
  "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)

Can I know why this happens?

Your usage of the scheduler is generally fine. The warning is raised because apex patches the optimizer.step method in a similar way to what the scheduler does in this line of code.

To avoid this warning, initialize the scheduler after running amp.initialize(model, optimizer, opt_level).
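
To make the ordering concrete, here is a rough sketch (MyModel is just a placeholder for your own model, and opt_level="O1" is only an example setting):

    import torch.optim as optim
    from apex import amp

    model = MyModel().cuda()    # placeholder: your model, defined elsewhere
    optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)

    # let amp patch optimizer.step first ...
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    # ... and only then create the scheduler, so it wraps the patched step
    scheduler = optim.lr_scheduler.CyclicLR(
        optimizer, base_lr=1e-3, max_lr=1e-2, step_size_up=2000)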

Also, if you want, you could add this check to avoid changing the learning rate if the optimization step was skipped due to a gradient overflow:

    optimizer.step()
    if amp._amp_state.loss_scalers[0]._unskipped != 0: # assuming you are using a single optimizer
        scheduler.step()
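
As far as I understand, with dynamic loss scaling amp skips the actual optimizer step when the scaled gradients overflow (and lowers the loss scale instead), and the _unskipped counter is reset to zero in that case, so this check advances the scheduler only when the optimizer really stepped.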

I want to ask about tracking the progress of training. Do you plot the loss value each epoch or each batch/iteration?

As far as I remember, I was tracking the loss value each epoch. May I know how this is relevant to the query I asked?