Hey folks,

I'm trying to find the names of the state variables an optimizer uses, without stepping the optimizer. For example, given an instance of `optim.SGD`, I'd like to know that it tracks `momentum_buffer`. Given an instance of `optim.Adam`, I'd like to know that it tracks `step`, `exp_avg`, and `exp_avg_sq`.

This is easy to get from `state_dict()['state']` once a step has been taken, but initially `state_dict()['state']` is empty. Is there another way to get the state variable names?

Thanks!
Sample code:

```python
import torch
from torch import nn
from torch import optim


class MyModel(nn.Module):
    def __init__(self) -> None:
        super(MyModel, self).__init__()
        self.layer_0 = nn.Linear(2, 3)
        self.layer_1 = nn.Linear(3, 4)

    def forward(self, x):
        return self.layer_1(self.layer_0(x))


mdl = MyModel()

opt = optim.Adam(mdl.parameters())
print(opt.state_dict())
# {'param_groups': [...], 'state': {}}
nn.MSELoss()(mdl(torch.rand(2)), torch.rand(4)).backward()
opt.step()
print(opt.state_dict())
# {'param_groups': [...], 'state': {
#   0: {'step': ..., 'exp_avg': tensor([...]), 'exp_avg_sq': tensor([...])},
#   1: {'step': ..., 'exp_avg': tensor([...]), 'exp_avg_sq': tensor([...])},
#   2: {'step': ..., 'exp_avg': tensor([...]), 'exp_avg_sq': tensor([...])},
#   3: {'step': ..., 'exp_avg': tensor([...]), 'exp_avg_sq': tensor([...])}}}

opt = optim.SGD(mdl.parameters(), lr=0.01, momentum=0.9)
print(opt.state_dict())
# {'param_groups': [...], 'state': {}}
nn.MSELoss()(mdl(torch.rand(2)), torch.rand(4)).backward()
opt.step()
print(opt.state_dict())
# {'param_groups': [...], 'state': {
#   0: {'momentum_buffer': tensor([...])},
#   1: {'momentum_buffer': tensor([...])},
#   2: {'momentum_buffer': tensor([...])},
#   3: {'momentum_buffer': tensor([...])}}}
```
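In case it's useful context: the closest workaround I've come up with is to take a dummy step on a throwaway optimizer of the same class, built around a single dummy parameter with zero gradients, and then read the state keys from that. This is just a sketch (`state_var_names` is a helper name I made up, and I'm not sure it's robust for every optimizer, e.g. ones whose state layout depends on flags like `amsgrad`):

```python
import torch
from torch import optim


def state_var_names(opt_class, **kwargs):
    """Return the state-variable names opt_class would create, by
    stepping a throwaway optimizer on a single dummy parameter."""
    p = torch.nn.Parameter(torch.zeros(1))
    tmp = opt_class([p], **kwargs)
    p.grad = torch.zeros_like(p)  # zero gradient is enough to populate state
    tmp.step()
    return list(tmp.state[p].keys())


print(state_var_names(optim.Adam))
print(state_var_names(optim.SGD, lr=0.01, momentum=0.9))
```

This still "steps" *an* optimizer, just not the real one, so it feels like a hack rather than an answer to the original question.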