Get current LR of optimizer with adaptive LR

How can I get the current learning rate being used by my optimizer?

Many of the optimizers in the torch.optim package use variable learning rates. You can provide an initial one, but the rate they actually apply should change depending on the data. I would like to be able to check the current rate being used at any given time.

This question is basically a duplicate of this one, but I don’t think that one was answered satisfactorily. Using Adam, for example, when I print:

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

'''...training...'''

for param_group in optimizer.param_groups:
    print(param_group['lr'])

I always see the initial learning rate no matter how many epochs of data I run.

However, if I want to restart interrupted training progress or even just debug my loss, it makes sense to know where the optimizer left off.

How can I do this?

You could save the optimizer state dict. Something like what is suggested here should work: https://github.com/pytorch/pytorch/issues/2830#issuecomment-336194949

After that you could just load it back.
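A minimal sketch of that save-and-restore round trip, assuming a toy `nn.Linear` model and a temporary file named `checkpoint.pt` (both are placeholders for illustration, not taken from the linked comment):

```python
import os
import tempfile

import torch
from torch import nn

torch.manual_seed(0)

model = nn.Linear(4, 2)  # toy model, assumed for illustration
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Take one training step so that Adam has internal state worth saving.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

with tempfile.TemporaryDirectory() as tmp_dir:
    checkpoint_path = os.path.join(tmp_dir, "checkpoint.pt")  # hypothetical file name
    torch.save(
        {"model": model.state_dict(), "optimizer": optimizer.state_dict()},
        checkpoint_path,
    )

    # Later, e.g. after a restart: rebuild the objects, then load the saved state.
    restored_model = nn.Linear(4, 2)
    restored_optimizer = torch.optim.Adam(restored_model.parameters(), lr=1e-4)
    checkpoint = torch.load(checkpoint_path)
    restored_model.load_state_dict(checkpoint["model"])
    restored_optimizer.load_state_dict(checkpoint["optimizer"])

# The restored optimizer carries the base lr and Adam's moment estimates.
print(restored_optimizer.state_dict()["param_groups"][0]["lr"])
```

The optimizer state dict holds both the `param_groups` (including `lr`) and the per-parameter buffers, so training resumes from where it stopped rather than from freshly initialized moments.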

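I am guessing the reason you see the same lr is that the lr for Adam has not changed: param_group['lr'] holds the base rate you passed in, and Adam itself never modifies it. What differs is the effective lr of each parameter, which also has components from the moment estimates.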

That makes sense; thanks for the link. Is there a way to show the effective learning rate, short of calculating it on my own?

I am sorry, but I am unaware of any direct way.
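For anyone who does want to calculate it, here is a rough sketch of one way to do so. It assumes plain `torch.optim.Adam` with default options (no `amsgrad`, no weight decay) and reads the `step`, `exp_avg`, and `exp_avg_sq` entries that Adam keeps in `optimizer.state`. The helper name `adam_effective_lr` is made up for this example, and "effective lr" is taken here to mean the per-element factor that multiplies the first-moment estimate in the update.

```python
import torch
from torch import nn


def adam_effective_lr(optimizer):
    """Return one tensor of per-element effective step sizes per parameter.

    Hypothetical helper: assumes default Adam (no amsgrad, no weight decay).
    """
    effective = []
    for group in optimizer.param_groups:
        lr, eps = group["lr"], group["eps"]
        beta1, beta2 = group["betas"]
        for p in group["params"]:
            state = optimizer.state.get(p)
            if not state:  # no step has been taken for this parameter yet
                continue
            step = float(state["step"])  # plain number or tensor, depending on the release
            bias_correction1 = 1 - beta1 ** step
            bias_correction2 = 1 - beta2 ** step
            denom = state["exp_avg_sq"].sqrt() / bias_correction2 ** 0.5 + eps
            # The update is: param -= (lr / bias_correction1) * exp_avg / denom
            effective.append(lr / bias_correction1 / denom)
    return effective


torch.manual_seed(0)
model = nn.Linear(4, 2)  # toy model, assumed for illustration
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

before = [p.detach().clone() for p in model.parameters()]
model(torch.randn(8, 4)).pow(2).mean().backward()
optimizer.step()

for p, eff in zip(model.parameters(), adam_effective_lr(optimizer)):
    print(tuple(p.shape), "base lr:", optimizer.param_groups[0]["lr"],
          "mean effective lr:", eff.mean().item())
```

The result is a tensor per parameter rather than a single number, which is probably why no built-in accessor reports it; reducing it with `mean()` or `max()` gives something that can be logged.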

For most optimizers all layers use the same lr, so you can just do:

print(optimizer.param_groups[0]['lr'])

If you’re using an lr_scheduler you can do the same, or use:

print(lr_scheduler.get_lr())

Nit: get_lr() might not yield the current learning rate, so you should use get_last_lr().
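A short sketch of reading the scheduled rate with `get_last_lr()`, assuming an SGD optimizer and a `StepLR` schedule chosen only for illustration:

```python
import torch
from torch import nn

model = nn.Linear(4, 2)  # toy model, assumed for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

history = []
for epoch in range(4):
    optimizer.zero_grad()
    model(torch.randn(8, 4)).pow(2).mean().backward()
    optimizer.step()
    scheduler.step()
    # get_last_lr() returns a list with one entry per parameter group.
    history.append(scheduler.get_last_lr()[0])
    print(epoch, scheduler.get_last_lr(), optimizer.param_groups[0]["lr"])
```

The value returned by `get_last_lr()` matches what is stored in `optimizer.param_groups`, so either can be logged once the scheduler has stepped.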

This is important! Otherwise the reported value is misleading. More info in this PR.

Maybe you could add some additional methods for the optimizer instance:

# -*- coding: utf-8 -*-
# @Time    : 2020/12/19
# @Author  : Lart Pang
# @GitHub  : https://github.com/lartpang
import types

from torch import nn
from torch.optim import SGD, Adam, AdamW


class OptimizerConstructor:
    def __init__(self, model, initial_lr, mode, group_mode, cfg):
        """
        A wrapper of the optimizer.

        :param model: nn.Module
        :param initial_lr: float
        :param mode: str
        :param group_mode: str
        :param cfg: A dict corresponding to your ``mode`` except the ``params`` and ``lr``.
        """

        self.mode = mode
        self.initial_lr = initial_lr
        self.group_mode = group_mode
        self.cfg = cfg
        self.params = self.group_params(model)

    def construct_optimizer(self):
        if self.mode == "sgd":
            optimizer = SGD(params=self.params, lr=self.initial_lr, **self.cfg)
        elif self.mode == "adamw":
            optimizer = AdamW(params=self.params, lr=self.initial_lr, **self.cfg)
        elif self.mode == "adam":
            optimizer = Adam(params=self.params, lr=self.initial_lr, **self.cfg)
        else:
            raise NotImplementedError
        return optimizer

    def group_params(self, model):
        ...  # parameter-grouping logic elided in the original post

    def __call__(self):
        optimizer = self.construct_optimizer()
        optimizer.lr_groups = types.MethodType(get_lr_groups, optimizer)
        optimizer.lr_string = types.MethodType(get_lr_strings, optimizer)
        return optimizer

def get_lr_groups(self):
    return [group["lr"] for group in self.param_groups]

def get_lr_strings(self):
    return ",".join([f"{group['lr']:10.3e}" for group in self.param_groups])
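Since `group_params` is elided above, here is a self-contained sketch of just the method-attaching idea, applied directly to an optimizer with two parameter groups. The model and the learning rates are made up for illustration.

```python
import types

import torch
from torch import nn


def get_lr_groups(self):
    # One learning rate per parameter group.
    return [group["lr"] for group in self.param_groups]


def get_lr_strings(self):
    # The same values, formatted for a log line.
    return ",".join(f"{group['lr']:10.3e}" for group in self.param_groups)


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
optimizer = torch.optim.SGD(
    [
        {"params": model[0].parameters(), "lr": 1e-3},
        {"params": model[2].parameters(), "lr": 1e-2},
    ],
    lr=1e-3,
    momentum=0.9,
)

# Bind the helpers to this optimizer instance only.
optimizer.lr_groups = types.MethodType(get_lr_groups, optimizer)
optimizer.lr_string = types.MethodType(get_lr_strings, optimizer)

print(optimizer.lr_groups())
print(optimizer.lr_string())
```

Note that these helpers still report the base rate stored in each parameter group, so they help when groups or schedulers differ, not for Adam's per-parameter effective rate.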