Basic question about the source code of optim.optimizer

Hi all,

I am trying to understand the source code of torch.optim.sgd (link to source). Since SGD inherits several features from the Optimizer class in torch.optim.optimizer (link to source), I am also looking at that. To start, I have two questions:

  1. The constructor of Optimizer contains self.state = defaultdict(dict). As far as I know, dict has to be a function that is called when a key that is not present in the state dictionary is accessed. Where is this dict function defined?

  2. The following method in the SGD source code:

    def __setstate__(self, state):
        super(SGD, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

It calls __setstate__ of the parent Optimizer class, which simply does self.__dict__.update(state). So I assume this is for loading an optimizer with previously saved settings? Then what is the purpose of setting all the ‘nesterov’ arguments to False?
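To see why self.state = defaultdict(dict) from the first question is convenient, here is a minimal plain-Python sketch of lazily initialised per-parameter state. It assumes strings stand in for parameter tensors, and update_momentum is a hypothetical helper loosely modelled on how SGD keeps a momentum buffer per parameter; it is not PyTorch's actual code.

```python
from collections import defaultdict

# Same construction as in the Optimizer constructor: a missing key maps to a new {}.
state = defaultdict(dict)


def update_momentum(param, grad, momentum=0.5):
    """Hypothetical helper: keep a running momentum buffer per parameter."""
    # For a parameter seen for the first time, this creates and stores an empty dict.
    param_state = state[param]
    if "momentum_buffer" not in param_state:
        # First step for this parameter: initialise the buffer with the gradient.
        param_state["momentum_buffer"] = grad
    else:
        buf = param_state["momentum_buffer"]
        param_state["momentum_buffer"] = momentum * buf + grad
    return param_state["momentum_buffer"]


update_momentum("weight", 1.0)
update_momentum("weight", 1.0)
print(dict(state))  # {'weight': {'momentum_buffer': 1.5}}
```

With a plain dict, the first state[param] lookup would raise a KeyError, so every access would need something like state.setdefault(param, {}) instead.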


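  1. defaultdict accepts a default_factory argument, which in this case is the built-in dict class. A class is callable, so calling dict() simply returns a new empty dict; nothing extra has to be defined. I’m not 100% sure, but I think this was used to return an empty dict for missing keys:

    from collections import defaultdict

    state = defaultdict()
    state
    > defaultdict(None, {})
    state['lala']
    > KeyError: 'lala'

    state = defaultdict(dict)
    state
    > defaultdict(<class 'dict'>, {})
    state['lala']
    > {}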
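  2. I don’t know exactly why this is needed. Note that setdefault only inserts ‘nesterov’ = False when a param group does not already contain that key; existing values are left untouched. I assume the parent class’s __getstate__/__setstate__ pair is used (e.g. when an optimizer is unpickled), and since SGD has an additional parameter, the corresponding __setstate__ has to fill in its default if the saved state doesn’t contain it.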
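A minimal plain-Python sketch of that mechanism, using hypothetical stand-ins Optimizer and SGD (not the real torch.optim classes). pickle calls __setstate__ when an object is restored, and the setdefault loop only fills in ‘nesterov’ for groups that lack it. The deleted key below merely simulates an optimizer saved before that option existed, which is an assumption for illustration.

```python
import pickle


class Optimizer:
    """Hypothetical stand-in for torch.optim.Optimizer (not the real class)."""

    def __init__(self, param_groups, defaults):
        self.defaults = defaults
        # Each group receives any hyperparameter it does not set explicitly.
        self.param_groups = [{**defaults, **g} for g in param_groups]

    def __getstate__(self):
        return {"defaults": self.defaults, "param_groups": self.param_groups}

    def __setstate__(self, state):
        # Mirrors the parent behaviour described in the thread.
        self.__dict__.update(state)


class SGD(Optimizer):
    def __init__(self, param_groups, lr, momentum=0, dampening=0, nesterov=False):
        defaults = dict(lr=lr, momentum=momentum, dampening=dampening, nesterov=nesterov)
        super().__init__(param_groups, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            # Inserts the key only when it is missing; existing values are kept.
            group.setdefault("nesterov", False)


old = SGD([{"params": ["w"]}], lr=0.1)
del old.param_groups[0]["nesterov"]  # simulate state saved before 'nesterov' existed
restored = pickle.loads(pickle.dumps(old))
print(restored.param_groups[0]["nesterov"])  # False

new = SGD([{"params": ["w"]}], lr=0.1, momentum=0.9, nesterov=True)
restored_new = pickle.loads(pickle.dumps(new))
print(restored_new.param_groups[0]["nesterov"])  # True
```

The second round trip shows that a saved nesterov=True is not reset to False.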

Thanks ptrblck!

  1. Makes sense!

  2. Right, since not all optimizers have a “nesterov” argument but it is read at the beginning of SGD’s step(), __setstate__ should ensure that the “nesterov” key is set. However, what about other parameters like “dampening”? It is also read at the beginning of step(), but it is not set in __setstate__. This would imply that every SGD param group is guaranteed to already contain this key…
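A small sketch of this reasoning, under the assumption (not confirmed in this thread) that “dampening” was always stored in saved SGD param groups while “nesterov” was added later. read_hyperparams and patch_group are hypothetical helpers: the first mimics step() reading keys from a group, the second mimics the per-group setdefault in SGD.__setstate__.

```python
def read_hyperparams(group):
    """Hypothetical stand-in for step(): plain indexing raises KeyError if a key is missing."""
    return group["momentum"], group["dampening"], group["nesterov"]


def patch_group(group):
    """Mimics SGD.__setstate__: only 'nesterov' gets a fallback value."""
    group.setdefault("nesterov", False)
    return group


# A group saved before 'nesterov' existed (assumed): patching makes it usable.
old_group = {"params": [], "lr": 0.1, "momentum": 0.9, "dampening": 0}
print(read_hyperparams(patch_group(old_group)))  # (0.9, 0, False)

# A group without 'dampening' is not repaired, so reading it still fails.
broken_group = {"params": [], "lr": 0.1, "momentum": 0.9}
try:
    read_hyperparams(patch_group(broken_group))
except KeyError as err:
    print("missing key:", err)  # missing key: 'dampening'
```

If groups without “dampening” could occur, __setstate__ would need a setdefault for it as well; its absence suggests only “nesterov” was expected to be missing from older saved state.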