When getting a state dict using <module>.state_dict(), the dictionary references the model's internal parameters. That means once the model changes, the dict changes with it. Usually this has no practical impact, since most people serialize the state to disk straight away.
If, however, you keep copies of the state dict in memory, you won't be able to restore from them, as their state always matches the network's current state.
I ran into this while implementing early stopping and it took me a while to figure out: loading the state_dict via load_state_dict (obviously) just had no effect.
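A minimal sketch of the issue, using a toy nn.Linear as a stand-in for the real model:

```python
import torch
import torch.nn as nn

# The values in state_dict() share storage with the live parameters,
# so the "snapshot" is not actually a snapshot.
net = nn.Linear(2, 2, bias=False)
snapshot = net.state_dict()          # no copy is made here
before = snapshot["weight"].clone()  # explicit copy, for comparison only

# Mutate the model in place (as an optimizer step would).
with torch.no_grad():
    net.weight.add_(1.0)

# The dict changed along with the model...
assert torch.equal(snapshot["weight"], net.weight)

# ...so loading it back is a no-op: the weights stay mutated.
net.load_state_dict(snapshot)
assert not torch.equal(net.weight, before)
```

This is exactly the early-stopping failure mode: the "best" state dict kept in memory silently tracks the current weights.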
In the interest of making the solution more discoverable I figured I’d describe my troubles here.
Do you think this would warrant a mention in the official documentation? If so, should I just create a Github issue?
Since the state_dict holds references to your model's parameters, you should clone it if you need to restore it later. As this uses more memory, it's not the default behavior.
Seeing that what is returned is an OrderedDict (at least in 0.4.0), you can't just call .clone() on the dict itself; you either have to clone each tensor: {k: v.clone() for k, v in net.state_dict().items()} or use copy.deepcopy.
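Both variants can be sketched like this (again with a toy nn.Linear):

```python
import copy
import torch
import torch.nn as nn

net = nn.Linear(2, 2, bias=False)

# Option 1: clone each tensor in the returned OrderedDict.
best_state = {k: v.clone() for k, v in net.state_dict().items()}
# Option 2: deep-copy the whole dict in one go.
best_state2 = copy.deepcopy(net.state_dict())

# Mutate the model in place (as training would).
with torch.no_grad():
    net.weight.add_(1.0)

# The copies are unaffected by the in-place update...
assert not torch.equal(best_state["weight"], net.weight)

# ...so restoring now actually works.
net.load_state_dict(best_state)
assert torch.equal(net.weight, best_state["weight"])
```

For early stopping, either copy can be kept as the "best so far" checkpoint and loaded back once training ends.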