State dictionary holds references to the internal state of the network

When you get a state dict via <module>.state_dict(), the dictionary references the model's internal parameters rather than copying them. That means once the model changes, the dict changes with it. Usually this doesn't matter, since most people serialize the state to disk straight away.
If, however, you keep copies of the state dict in memory, you won't be able to restore from them, because their contents always match the network's current state.
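
To make the aliasing concrete, here is a minimal sketch (the nn.Linear model is just a stand-in for any network):

import torch
import torch.nn as nn

model = nn.Linear(2, 2)
state = model.state_dict()               # holds references, not copies
saved_weight = state['weight'].clone()   # an actual snapshot, for comparison

with torch.no_grad():
    model.weight.add_(1.0)               # change the model in place

print(torch.equal(state['weight'], saved_weight))  # False: the dict moved with the model
print(torch.equal(state['weight'], model.weight))  # True: it still tracks the live weights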

I ran into this while implementing early stopping, and it took me a while to figure out. Loading the state dict with load_state_dict (obviously) just had no effect.
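
For illustration, a sketch of the pattern that silently fails; the training-loop helpers here (train_one_epoch, validation_improved) are hypothetical placeholders:

num_epochs = 100
best_state = model.state_dict()     # reference, not a snapshot
for epoch in range(num_epochs):
    train_one_epoch(model)          # hypothetical training step
    if validation_improved():       # hypothetical early-stopping check
        best_state = model.state_dict()

# No-op: best_state already equals the current weights,
# so this loads the latest parameters back into themselves.
model.load_state_dict(best_state)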

In the interest of making the solution more discoverable I figured I’d describe my troubles here.

Do you think this would warrant a mention in the official documentation? If so, should I just create a GitHub issue?

Since the state_dict holds references to your model's parameters, you should clone it if you need to restore it later. As this uses more memory, it's not the default behavior.

Maybe it’s a good idea to mention it in the docs.

Seeing that what is returned is an OrderedDict (at least in 0.4.0), you can't just call .clone() on the dict itself; you would either have to clone each tensor:
{k: v.clone() for k, v in net.state_dict().items()}
or use copy.deepcopy.
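
Both options in runnable form (the net here is assumed to be any nn.Module):

import copy
import torch.nn as nn

net = nn.Linear(4, 2)

# Option 1: clone every tensor in the state dict
snapshot = {k: v.clone() for k, v in net.state_dict().items()}

# Option 2: deep-copy the whole OrderedDict
snapshot = copy.deepcopy(net.state_dict())

# Either snapshot is now detached from the live parameters
# and can safely be restored later with net.load_state_dict(snapshot).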

You are right. Sorry for not being clear enough.
Here is another small code example:

# Snapshot the current parameters by cloning each tensor,
# so the copies are detached from the live model.
old_state_dict = {}
for key, value in model.state_dict().items():
    old_state_dict[key] = value.clone()
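
Later, restoring the snapshot works as expected, since load_state_dict accepts a plain dict of tensors:

model.load_state_dict(old_state_dict)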