Detect optimizer state variable names without stepping

Hey folks,

I’m trying to find the names of the state variables an optimizer uses, without stepping the optimizer.

For example, given an instance of optim.SGD with momentum, I’d like to know that it tracks momentum_buffer. Given an instance of optim.Adam, I’d like to know that it tracks step, exp_avg, and exp_avg_sq.

This is easy to get from state_dict()['state'] once a step has been taken, but initially state_dict()['state'] is empty.

Is there another way to get the names of the state variables?

Thanks!

Sample code:

import torch
from torch import nn
from torch import optim

class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer_0 = nn.Linear(2, 3)
        self.layer_1 = nn.Linear(3, 4)

    def forward(self, x):
        return self.layer_1(self.layer_0(x))


mdl = MyModel()
opt = optim.Adam(mdl.parameters())

print(opt.state_dict())
# {'param_groups': [...], 'state': {}}

nn.MSELoss()(mdl(torch.rand(2)), torch.rand(4)).backward()
opt.step()

print(opt.state_dict())
# {'param_groups': [...], 'state': {
#     0: {'step': ..., 'exp_avg': tensor([...]), 'exp_avg_sq': tensor([...])},
#     1: {'step': ..., 'exp_avg': tensor([...]), 'exp_avg_sq': tensor([...])},
#     2: {'step': ..., 'exp_avg': tensor([...]), 'exp_avg_sq': tensor([...])},
#     3: {'step': ..., 'exp_avg': tensor([...]), 'exp_avg_sq': tensor([...])}}}

opt = optim.SGD(mdl.parameters(), lr=0.01, momentum=0.9)

print(opt.state_dict())
# {'param_groups': [...], 'state': {}}

nn.MSELoss()(mdl(torch.rand(2)), torch.rand(4)).backward()
opt.step()

print(opt.state_dict())
# {'param_groups': [...], 'state': {
#     0: {'momentum_buffer': tensor([...])},
#     1: {'momentum_buffer': tensor([...])},
#     2: {'momentum_buffer': tensor([...])},
#     3: {'momentum_buffer': tensor([...])}}}

No, I don’t think that’s possible, since the state is lazily initialized, as seen here in the step() function.
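
If you only need the key names, though, you can exploit that same lazy initialization: give a throwaway parameter a zero gradient and step a scratch optimizer instance once, then read the keys out of its state. Here is a minimal sketch (the helper name is mine, not a PyTorch API; it assumes the optimizer populates the same keys for every dense parameter):

import torch
from torch import optim

def optimizer_state_keys(optimizer_cls, **optimizer_kwargs):
    # Throwaway parameter; your real model is never touched.
    param = torch.nn.Parameter(torch.zeros(1))
    scratch_opt = optimizer_cls([param], **optimizer_kwargs)
    # A zero gradient is enough to trigger the lazy state initialization.
    param.grad = torch.zeros_like(param)
    scratch_opt.step()
    # All parameters get the same keys, so inspecting one is enough.
    return set(next(iter(scratch_opt.state.values())).keys())

print(optimizer_state_keys(optim.Adam))
# {'step', 'exp_avg', 'exp_avg_sq'}
print(optimizer_state_keys(optim.SGD, lr=0.01, momentum=0.9))
# {'momentum_buffer'}

One caveat: pass the same constructor arguments you use for the real optimizer, since options change the keys (e.g. Adam with amsgrad=True adds max_exp_avg_sq, and SGD without momentum keeps no state at all).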