Updating Optimizer Hyperparameters

Just want to check whether I am doing this correctly, because when I print optimizer.defaults, the values are not updated.

import torch
import torch.nn as nn

dummy_model = nn.Linear(10, 2)

optimizer = torch.optim.SGD(dummy_model.parameters(), lr=0.01, momentum=0.9, nesterov=True)

# Update the momentum to a new value:
optimizer.param_groups[0]['momentum'] = 0.8

print(optimizer.param_groups[0])

print(optimizer.defaults)

The first print statement shows the expected momentum of 0.8. The second still shows 0.9. So will this have the desired effect of changing the momentum during training?
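One way to check whether step() really uses the value in param_groups rather than optimizer.defaults is to compute a single update by hand. This is a minimal sketch, assuming plain SGD with Nesterov momentum as in the question; it swaps the Linear layer for a hypothetical single scalar parameter so the arithmetic stays simple. On the first step SGD initializes the momentum buffer to the gradient, and with nesterov=True the applied update is grad + momentum * buf.

```python
import torch
import torch.nn as nn

# Hypothetical single scalar parameter, so the first update is easy to compute by hand.
param = nn.Parameter(torch.tensor([1.0]))
optimizer = torch.optim.SGD([param], lr=0.01, momentum=0.9, nesterov=True)

# Change momentum on every group; optimizer.defaults will still report 0.9.
for group in optimizer.param_groups:
    group["momentum"] = 0.8

# loss = param, so the gradient is exactly 1.0.
loss = param.sum()
loss.backward()
optimizer.step()

# First Nesterov step: buf = grad = 1, update = grad + momentum * buf = 1.8,
# so param = 1.0 - 0.01 * 1.8 = 0.982 (up to float32 rounding).
# With the default momentum of 0.9 it would have been 0.981.
print(param.item())
print(optimizer.defaults["momentum"])
```

If the parameter ends up near 0.982 rather than 0.981, the optimizer used the 0.8 stored in the param group, even though optimizer.defaults was never touched.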

Hi @J_Johnson,

If you want to change the momentum for all parameters, you need to apply it to every parameter group, not just the first one. So:

for group in optimizer.param_groups:
  group['momentum']=0.8
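The loop above can be wrapped in a small helper so it can be reused for any option. This is a sketch only: set_hyperparameter is a hypothetical name, not part of torch.optim, and the KeyError check is an added guard against typos in the option name.

```python
import torch
import torch.nn as nn


def set_hyperparameter(optimizer, name, value):
    """Hypothetical helper: set option `name` to `value` in every param group."""
    for group in optimizer.param_groups:
        if name not in group:
            # Guard against typos such as "momentun"; every group normally
            # contains all of the optimizer's options.
            raise KeyError(f"{name!r} is not an option of this optimizer")
        group[name] = value


dummy_model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(dummy_model.parameters(), lr=0.01, momentum=0.9, nesterov=True)

set_hyperparameter(optimizer, "momentum", 0.8)
print([g["momentum"] for g in optimizer.param_groups])
```

Raising on an unknown key is a design choice: plain assignment would silently add a new key that the optimizer never reads.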

That’s an interesting point, and it raises another question. I’ve never seen more than one group in optimizer.param_groups. Do you know of any example that would produce more than one group? Just curious.

Ah, ignore what I said (although it does apply to more complicated models). The .defaults attribute is basically a dict of the default hyperparameter values the optimizer was constructed with, not values belonging to your network. There’s a bit more discussion in the comments of the torch.optim.Optimizer source code.

When you update via,

for group in optimizer.param_groups:
  group['momentum']=0.8

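it updates the value that optimizer.step() actually uses for each group, overriding the default. optimizer.defaults itself is not modified, which is why it still shows 0.9.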

I’ve always used for group in optimizer.param_groups rather than optimizer.param_groups[0], because that’s how it’s done within the base class and it covers optimizers with more than one group. The torch.optim documentation has an example with more than one group (under per-parameter options).
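As for how more than one group arises: passing a list of dicts to the optimizer creates one group per dict, each of which can override the constructor's defaults. A minimal sketch, assuming a hypothetical two-part model (a "base" layer and a "classifier" head) in the spirit of the per-parameter options example:

```python
import torch
import torch.nn as nn

# Hypothetical model: model[0] acts as the base, model[2] as the classifier head.
model = nn.Sequential(nn.Linear(10, 8), nn.ReLU(), nn.Linear(8, 2))
base, classifier = model[0], model[2]

optimizer = torch.optim.SGD(
    [
        {"params": base.parameters()},                     # falls back to lr=0.01
        {"params": classifier.parameters(), "lr": 1e-3},   # overrides lr for this group
    ],
    lr=0.01,
    momentum=0.9,
)

print(len(optimizer.param_groups))  # 2

# Updating only param_groups[0] would miss the classifier, so loop over all groups.
for group in optimizer.param_groups:
    group["momentum"] = 0.8

print([(g["lr"], g["momentum"]) for g in optimizer.param_groups])  # [(0.01, 0.8), (0.001, 0.8)]
```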

Okay, so param_groups is the part I need to modify in order to change optimizer hyperparameters between training epochs.
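A sketch of what that looks like in a training loop, assuming a hypothetical per-epoch momentum schedule and random stand-in data in place of a real dataset. The value is set at the start of each epoch, so every step() in that epoch uses it.

```python
import torch
import torch.nn as nn

dummy_model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(dummy_model.parameters(), lr=0.01, momentum=0.9, nesterov=True)

momentum_schedule = [0.9, 0.85, 0.8]                       # hypothetical schedule, one value per epoch
loss_fn = nn.MSELoss()
inputs, targets = torch.randn(16, 10), torch.randn(16, 2)  # random stand-in data

for epoch, momentum in enumerate(momentum_schedule):
    # Update every param group before any step() call in this epoch.
    for group in optimizer.param_groups:
        group["momentum"] = momentum

    # A single full-batch step stands in for one epoch of training.
    optimizer.zero_grad()
    loss = loss_fn(dummy_model(inputs), targets)
    loss.backward()
    optimizer.step()
    print(epoch, optimizer.param_groups[0]["momentum"], loss.item())
```

For the learning rate specifically, the schedulers in torch.optim.lr_scheduler do this bookkeeping for you; a manual loop like this one is the simple option for other hyperparameters such as momentum.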

By the way, you can also do:

optimizer.defaults['momentum']=0.8

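which updates the defaults. However, it does not change the values that existing parameter groups use in step(); the defaults are only used to fill in options missing from a group when that group is added. Perhaps the defaults are kept so you can log what the starting optimizer values were, in which case there is little point in updating them.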
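The difference can be seen by adding a group after changing the defaults. A minimal sketch, assuming a hypothetical head layer that is registered with the optimizer later via add_param_group(), which fills in any option the new group does not specify from optimizer.defaults:

```python
import torch
import torch.nn as nn

backbone = nn.Linear(10, 4)
head = nn.Linear(4, 2)  # hypothetical layer added to the optimizer later

optimizer = torch.optim.SGD(backbone.parameters(), lr=0.01, momentum=0.9, nesterov=True)

# Changing the defaults does not touch the existing group...
optimizer.defaults["momentum"] = 0.8
print(optimizer.param_groups[0]["momentum"])  # 0.9

# ...but a group added afterwards picks up the new default for options it omits.
optimizer.add_param_group({"params": head.parameters()})
print(optimizer.param_groups[1]["momentum"])  # 0.8
```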