Learning Rate Scheduler Not Working as Expected

I tried to implement a learning rate scheduler using StepLR in PyTorch, following the instructions provided. This is my code:

        optimizer = optim.SGD(model.parameters(), lr=LR, weight_decay=decay, momentum=momentum, dampening=dampening)
        scheduler = StepLR(optimizer, step_size=2, gamma=0.1)
        trainset = TrainDataset(train, trainlabels)
        train_loader = torch.utils.data.DataLoader(
                trainset,
                batch_size=batch_size,
                shuffle=True,
                )
        for epoch in range(0, 6):
            print('Epoch: ', epoch, ', LR: ', scheduler.get_lr())
            running_loss = 0
            batches = 0
            for inputs, labels in train_loader:
                batches = batches+1
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                output = model(inputs)
                output = output.squeeze()
                _, dimx, dimy = output.shape
                loss = criterion(output)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
            scheduler.step()
            print('Loss :{:.4f} Epoch[{}/{}]'.format(running_loss/batches, epoch, epochs))

The learning rate scheduler does not work as expected. This is the output:

Epoch: 0 , LR: [0.1]
Loss :0.1981 Epoch[0/6]

Epoch: 1 , LR: [0.1]
Loss :0.1957 Epoch[1/6]

Epoch: 2 , LR: [0.0010000000000000002]
Loss :0.1360 Epoch[2/6]

Epoch: 3 , LR: [0.010000000000000002]
Loss :0.1332 Epoch[3/6]

Epoch: 4 , LR: [0.00010000000000000003]
Loss :0.1293 Epoch[4/6]

Epoch: 5 , LR: [0.0010000000000000002]
Loss :0.1289 Epoch[5/6]

As you can see, in Epoch 2 and Epoch 4 (the transition epochs to the new learning rate), the reported learning rates are 10x lower than they should be. Why is this happening? Have I implemented StepLR correctly?

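I think the issue is with scheduler.get_lr(). Reading the implementation, it seems you should use scheduler.get_last_lr() instead. As far as I can tell, get_lr() is what step() calls internally to compute the next value, and on a transition epoch it multiplies the optimizer's already-decayed learning rate by gamma once more, which would explain the extra 10x drop printed for epochs 2 and 4.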
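To check this, here is a minimal sketch that prints the schedule with get_last_lr(). It assumes PyTorch 1.4 or newer, where get_last_lr() is available. The one-layer model is only a hypothetical stand-in for the real network, and lr=0.1 is taken from the output in the question.

```python
import torch
from torch import optim
from torch.optim.lr_scheduler import StepLR

# Hypothetical stand-in model; any parameters work for inspecting the schedule.
model = torch.nn.Linear(1, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = StepLR(optimizer, step_size=2, gamma=0.1)

observed = []
for epoch in range(6):
    # get_last_lr() reports the learning rate currently set on the optimizer.
    lr = scheduler.get_last_lr()[0]
    observed.append(lr)
    print('Epoch:', epoch, ', LR:', lr)

    # The real training loop for this epoch would go here.
    optimizer.step()
    # Advance the schedule once per epoch, after the optimizer has stepped.
    scheduler.step()
```

With step_size=2 and gamma=0.1, the printed value should stay constant for two epochs and then drop by a factor of 10, up to floating-point rounding. Reading optimizer.param_groups[0]['lr'] directly should give the same numbers.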
