Learning rate decay during training

I am trying to apply a custom learning-rate decay to the Adam optimizer, updating the rate at every training step (global step) according to the function below:

from math import pow


# Question snippet: `global_step`, `torch` and `model` are not defined in this
# fragment; presumably they come from the surrounding training script.
init_learning_rate = 1e-3
min_learning_rate = 1e-5
decay_rate = 0.9999

# Decays geometrically from init_learning_rate (at global_step 0) towards the
# floor min_learning_rate, since decay_rate < 1.
lr = ((init_learning_rate - min_learning_rate) *
      pow(decay_rate, global_step) +
      min_learning_rate)

# lr is evaluated once here, so the optimizer is given a single fixed value.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

Do I have to re-create the optimizer at every training step, or can its learning rate be updated in place?

import math
from math import pow

import torch
import torchvision

# Example model; only model.parameters() is used (to build the Adam optimizer).
model = torchvision.models.AlexNet()

def lr_decay(global_step, init_learning_rate=1e-3, min_learning_rate=1e-5,
             decay_rate=0.9999):
    """Return the learning rate to use at ``global_step``.

    For 0 < decay_rate < 1 the rate starts at (approximately)
    ``init_learning_rate`` at step 0 and shrinks geometrically by
    ``decay_rate`` per step towards the floor ``min_learning_rate``:

        lr = (init_learning_rate - min_learning_rate)
             * decay_rate ** global_step + min_learning_rate
    """
    # math.pow (imported at file level) always yields a float.
    decay_factor = pow(decay_rate, global_step)
    span = init_learning_rate - min_learning_rate
    return span * decay_factor + min_learning_rate

# No need to re-create the optimizer: build it once with the step-0 rate and
# let LambdaLR rescale it. LambdaLR sets lr = base_lr * lr_lambda(step), so the
# lambda returns the decayed rate *relative to* the initial rate lr0.
lr0 = lr_decay(0)
optimizer = torch.optim.Adam(model.parameters(), lr=lr0)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: lr_decay(step) / lr0,
)
for step in range(10):
    current_lr = optimizer.param_groups[0]['lr']
    expected_lr = lr_decay(step)
    print(step, current_lr, expected_lr)
    # lr0 * (lr_decay(step) / lr0) may differ from lr_decay(step) in the last
    # bit, so compare with a tolerance rather than exact float equality.
    assert math.isclose(current_lr, expected_lr, rel_tol=1e-9)
    # loss.backward()  # in real training, compute gradients before stepping
    optimizer.step()
    scheduler.step()  # after optimizer.step(), as PyTorch expects
2 Likes