Does model.load_state_dict(strict=False) ignore new parameters introduced in my model's constructor?

For load_state_dict, the documentation states:
Whether you are loading from a partial *state_dict*, which is missing some keys, or loading a *state_dict* with more keys than the model that you are loading into, you can set the strict argument to **False** in the load_state_dict() function to ignore non-matching keys. (from https://pytorch.org/tutorials/beginner/saving_loading_models.html#id4)

I have created a new model by augmenting the architecture slightly, doubling the number of trainable parameters (by introducing more modules). I want to retain the parameters from a checkpoint, keeping them trainable so I can fine-tune my model on a different dataset. I want to make sure that the newly introduced parameters will be trained rather than ignored when I set strict=False. If not, should I look into doing something else? Thanks!

load_state_dict() just loads the weights for you. Once the weights have been loaded, with either strict=True or strict=False, its job is done.

Suppose there is a ModelA which you trained previously, and you saved its weights as "modelA.pth".
Now you create an identical model called ModelB, and you want to initialise its weights with ModelA's.
In this case,
ModelB.load_state_dict(torch.load("modelA.pth")) would work.
Note that whether you use strict=False or strict=True here, no error will be thrown.
The reason is that ModelA and ModelB have exactly the same layers, so every key in the saved state_dict matches and there is no problem loading it.

But consider a scenario where your ModelB has some extra layers or is missing some layers.
Now, if you use something like
ModelB.load_state_dict(torch.load("modelA.pth"), strict=True), it will raise a RuntimeError listing the missing or unexpected keys.
Keeping strict=True is like assuring PyTorch that both your models are identical.
On the contrary, if you use strict=False, you inform PyTorch that ModelB and ModelA are not identical, so it just loads the parameters of the layers which are present in both and ignores the rest. Layers that exist only in ModelB keep whatever values they were initialised with in the constructor.
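
To make the mismatch case concrete, here is a minimal sketch. The two toy classes and the layer names fc and extra are made up for illustration, and the state_dict is taken from memory instead of from a "modelA.pth" file so the snippet runs on its own.

```python
import torch
import torch.nn as nn

# Hypothetical toy models: ModelB is ModelA plus one extra layer.
class ModelA(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 3)

class ModelB(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 3)     # present in both models
        self.extra = nn.Linear(3, 2)  # new layer, absent from the checkpoint

model_a = ModelA()
model_b = ModelB()
state = model_a.state_dict()  # stands in for torch.load("modelA.pth")

# strict=True: the keys do not match exactly, so a RuntimeError is raised.
try:
    model_b.load_state_dict(state, strict=True)
    strict_failed = False
except RuntimeError:
    strict_failed = True

# strict=False: matching keys are copied, the mismatches are only reported.
result = model_b.load_state_dict(state, strict=False)
print(sorted(result.missing_keys))     # ['extra.bias', 'extra.weight']
print(sorted(result.unexpected_keys))  # []

# The shared layer now holds ModelA's weights.
shared_copied = torch.equal(model_b.fc.weight, model_a.fc.weight)
```

The return value is worth checking after a strict=False load, since a typo in a layer name would otherwise be skipped silently. Note also that strict=False only relaxes key matching: a key that exists in both models but with a different tensor shape still raises an error.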

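P.S.: The strict argument has nothing to do with whether parameters are trained or not. It only controls how keys are matched while loading. Newly introduced parameters are created with requires_grad=True by default, so they will be trained like any others as long as they are passed to the optimizer.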
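
If you want to verify this for your own model, a quick check along these lines may help. It is only a sketch with hypothetical toy models built from nn.Sequential, a random batch, and a dummy loss, not your actual architecture.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: a "pretrained" model and an augmented one
# that adds a second layer the checkpoint knows nothing about.
pretrained = nn.Sequential(nn.Linear(4, 3))
augmented = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 2))

# Keys "0.weight" and "0.bias" match; "1.weight" and "1.bias" are skipped.
augmented.load_state_dict(pretrained.state_dict(), strict=False)

# Every parameter, loaded or new, is still marked as trainable.
trainable = {name: p.requires_grad for name, p in augmented.named_parameters()}
print(trainable)

# One dummy training step: all parameters are given to the optimizer
# and all of them receive a gradient.
optimizer = torch.optim.SGD(augmented.parameters(), lr=0.1)
loss = augmented(torch.randn(8, 4)).pow(2).mean()
loss.backward()
has_grad = all(p.grad is not None for p in augmented.parameters())
optimizer.step()
```

If you ever want the opposite, i.e. to freeze the loaded layers, that has to be done explicitly by setting requires_grad = False on those parameters; strict will not do it for you.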
