lr_scheduler.StepLR() seems to change the LR incorrectly on the first .step() call

I was running this very simple (toy) code to observe how the LR changes across epochs when a scheduler is used. I used the simple StepLR().

import torch
print('PyTorch version: {0}'.format(torch.__version__))

model = Net(n_in, n_out) # 'Net' is a simple MLP
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

print('Initial LR : {0:.8f}'.format(scheduler.get_lr()[0]))
for e in range(8):
    scheduler.step()
    print('Epoch {0}, LR: {1:.8f}'.format(e, scheduler.get_lr()[0]))
    for i in range(5):
        optimizer.zero_grad()

        y = model(x) # x is proper sized input tensor
        l = y.mean()
        l.backward()
        
        optimizer.step()

When I ran this code with PyTorch 1.1.0, it produced

PyTorch version: 1.1.0
Initial LR : 0.01000000
Epoch 0, LR: 0.00010000
Epoch 1, LR: 0.00001000
Epoch 2, LR: 0.00000100
Epoch 3, LR: 0.00000010
Epoch 4, LR: 0.00000001
Epoch 5, LR: 0.00000000
Epoch 6, LR: 0.00000000
Epoch 7, LR: 0.00000000

My question is: why does the first invocation of scheduler.step() multiply the LR by 0.01 and not by 0.1 (i.e., gamma)? The later calls to .step() behave as expected.
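To make explicit what I expected: StepLR is documented to multiply the LR by gamma every step_size epochs, so the schedule can be computed by hand. This little snippet is just my own sanity check, not scheduler internals:

base_lr, gamma, step_size = 1e-2, 0.1, 1
for e in range(8):
    # after (e + 1) calls to scheduler.step() with step_size=1,
    # I expect lr = base_lr * gamma ** (e + 1), i.e. 0.001 after the first step
    expected = base_lr * gamma ** ((e + 1) // step_size)
    print('Epoch {0}, expected LR: {1:.8f}'.format(e, expected))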

When I investigated further, I found that this behaviour is different in PyTorch 0.4.0, which produced the following output. This seems to be correct.

PyTorch version: 0.4.0
Initial LR : 0.10000000
Epoch 0, LR: 0.01000000
Epoch 1, LR: 0.00100000
Epoch 2, LR: 0.00010000
Epoch 3, LR: 0.00001000
Epoch 4, LR: 0.00000100
Epoch 5, LR: 0.00000010
Epoch 6, LR: 0.00000001
Epoch 7, LR: 0.00000000

This seems to be related to this issue.
If you print the optimizer’s learning rate, you should see that it’s behaving as expected:

import torch
import torch.nn as nn

model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

print('Initial LR : {0:.8f}'.format(scheduler.get_lr()[0]))
for e in range(8):
    scheduler.step()
    print('Epoch {0}, LR: {1:.8f}, opt LR {2:.8f}'.format(e, scheduler.get_lr()[0],
          optimizer.param_groups[0]['lr']))
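As a side note, since PyTorch 1.1.0 the scheduler is expected to be stepped after optimizer.step() (usually once per epoch, after the inner training loop); stepping it before skips the first value of the schedule. A minimal self-contained sketch of that ordering, with a dummy model and input just for illustration:

import torch
import torch.nn as nn

model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
x = torch.randn(4, 1)  # dummy input, just for this sketch

for epoch in range(8):
    for i in range(5):
        optimizer.zero_grad()
        loss = model(x).mean()
        loss.backward()
        optimizer.step()
    scheduler.step()  # step the scheduler once per epoch, after optimizer.step()
    print('Epoch {0}, opt LR: {1:.8f}'.format(epoch, optimizer.param_groups[0]['lr']))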

Dear @ptrblck

I ran this code on Colab and the output is not consistent: link to colab

import torch
import torch.nn as nn
print("pytorch version", torch.__version__)

model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[4, 7], gamma=0.1)

print('Initial LR : {0:.8f}'.format(scheduler.get_lr()[0]))
for e in range(8):
    scheduler.step()
    print('Epoch {0}, LR: {1:.8f}, opt LR {2:.8f}'.format(e, scheduler.get_lr()[0],
          optimizer.param_groups[0]['lr']))

pytorch version 1.4.0
Initial LR : 0.10000000
Epoch 0, LR: 0.10000000, opt LR 0.10000000
Epoch 1, LR: 0.10000000, opt LR 0.10000000
Epoch 2, LR: 0.10000000, opt LR 0.10000000
Epoch 3, LR: 0.00100000, opt LR 0.01000000
Epoch 4, LR: 0.01000000, opt LR 0.01000000
Epoch 5, LR: 0.01000000, opt LR 0.01000000
Epoch 6, LR: 0.00010000, opt LR 0.00100000
Epoch 7, LR: 0.00100000, opt LR 0.00100000

Even this is inconsistent on Google Colab. With step_size = 2 I receive the following output:

pytorch version 1.4.0
Initial LR : 0.10000000
Epoch 0, LR: 0.10000000, opt LR 0.10000000
Epoch 1, LR: 0.00100000, opt LR 0.01000000
Epoch 2, LR: 0.01000000, opt LR 0.01000000
Epoch 3, LR: 0.00010000, opt LR 0.00100000
Epoch 4, LR: 0.00100000, opt LR 0.00100000
Epoch 5, LR: 0.00001000, opt LR 0.00010000
Epoch 6, LR: 0.00010000, opt LR 0.00010000
Epoch 7, LR: 0.00000100, opt LR 0.00001000

You should use get_last_lr() instead in PyTorch 1.4.
Refer to https://github.com/pytorch/pytorch/pull/26423
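A minimal sketch of the same loop using get_last_lr() (assuming PyTorch >= 1.4); it returns the learning rate the scheduler last computed, so it stays consistent with the optimizer's param group:

import torch
import torch.nn as nn

model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[4, 7], gamma=0.1)

print('Initial LR : {0:.8f}'.format(scheduler.get_last_lr()[0]))
for e in range(8):
    scheduler.step()
    # get_last_lr() agrees with the optimizer's actual LR, unlike get_lr() here
    print('Epoch {0}, LR: {1:.8f}, opt LR {2:.8f}'.format(
        e, scheduler.get_last_lr()[0], optimizer.param_groups[0]['lr']))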
