strict=False in load_state_dict

What does strict=False do in load_state_dict?

I read that it allows loading with missing parameters. For example, say my module has 4 convolution layers, each followed by BN and ReLU, and my .pth file has 3 (or 5) convolution layers, each followed by BN and ReLU. Is it then possible to load the weights using this argument? Am I right?

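It will, as long as the ones that match have the same names and the same parameter shapes. A shape mismatch on a matching name still raises an error, even with strict=False.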
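A minimal sketch of the 3/4/5-block case from the question, assuming the blocks are stacked in a single nn.Sequential (so keys are named by index, e.g. `0.weight`, `1.running_mean`) and every block uses 8 channels. `make_net` is a hypothetical helper, not part of PyTorch. The last part shows that a shape mismatch on a shared name is not skipped by strict=False.

```python
import torch.nn as nn


def make_net(num_blocks: int) -> nn.Sequential:
    """Hypothetical helper: stack `num_blocks` Conv -> BN -> ReLU blocks."""
    layers = []
    in_ch = 3
    for _ in range(num_blocks):
        layers += [
            nn.Conv2d(in_ch, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(),
        ]
        in_ch = 8
    return nn.Sequential(*layers)


model = make_net(4)  # the module we want to fill

# Shallower checkpoint: blocks 1-3 load, block 4 (indices 9 and 10) keeps its init.
res_small = model.load_state_dict(make_net(3).state_dict(), strict=False)
print("missing:", res_small.missing_keys)

# Deeper checkpoint: blocks 1-4 load, block 5 (indices 12 and 13) is ignored.
res_big = model.load_state_dict(make_net(5).state_dict(), strict=False)
print("unexpected:", res_big.unexpected_keys)

# Same name, different shape: strict=False does NOT skip this, it raises.
wide = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, padding=1))
mismatch_raised = False
try:
    make_net(1).load_state_dict(wide.state_dict(), strict=False)
except RuntimeError as err:
    mismatch_raised = "size mismatch" in str(err)
print("size mismatch raised:", mismatch_raised)
```

Because nn.Sequential names layers by position, this only lines up if the checkpoint's model was built in the same order. With a different layout you would need to rename keys first.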

In general, it is used when you extend a given Module to add extra stuff but still want to be able to load a checkpoint from the original module that contains the parameters for all the common pieces.
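A sketch of the "extend a Module" case, assuming a hypothetical `Backbone` whose checkpoint you already have and a subclass `ExtendedBackbone` that adds a classification head. The in-memory `state_dict()` stands in for a checkpoint you would normally get from `torch.load`.

```python
import torch
import torch.nn as nn


class Backbone(nn.Module):
    """Hypothetical 'original' module whose checkpoint we already have."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(8)

    def forward(self, x):
        return torch.relu(self.bn(self.conv(x)))


class ExtendedBackbone(Backbone):
    """Same layers (same attribute names) as Backbone, plus an extra head."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.head = nn.Linear(8, num_classes)

    def forward(self, x):
        feats = super().forward(x).mean(dim=(2, 3))  # global average pool
        return self.head(feats)


original = Backbone()
checkpoint = original.state_dict()  # stands in for a loaded .pth file

extended = ExtendedBackbone()
result = extended.load_state_dict(checkpoint, strict=False)
print(result.missing_keys)  # only the new head is left at its init
```

Because the subclass keeps the attribute names `conv` and `bn`, the shared keys match and only `head.*` is reported as missing.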


To expand: a state_dict is like a normal Python dictionary (an OrderedDict mapping names to tensors). With the default strict=True, loading works if and only if the dictionary has keys with exactly the same names as the model's parameters (and persistent buffers, such as BatchNorm running statistics) AND nothing else.

With strict=False, you are saying that you don't mind if parameters that are not included (by name) in the dictionary don't get loaded, and that any keys the module does not use should be ignored.
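A small sketch contrasting the two modes on a toy model (the model and the key `extra.weight` are made up for illustration). In recent PyTorch versions, load_state_dict returns a named tuple with `missing_keys` and `unexpected_keys` fields, which is a handy way to check what strict=False actually skipped.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
keys = list(model.state_dict().keys())
print(keys)  # parameters AND buffers, e.g. '1.running_mean'

# Drop one expected key and add one the model does not have.
partial = {k: v for k, v in model.state_dict().items() if k != "1.running_var"}
partial["extra.weight"] = torch.zeros(1)

strict_raised = False
try:
    model.load_state_dict(partial)  # strict=True is the default
except RuntimeError as err:
    strict_raised = True
    print(err)  # reports the missing and the unexpected key

result = model.load_state_dict(partial, strict=False)
print(result.missing_keys)     # ['1.running_var']
print(result.unexpected_keys)  # ['extra.weight']
```

Checking these two lists after a strict=False load is a cheap way to catch a typo in key names that would otherwise leave layers silently uninitialized.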

strict=False is useful, for example, when you create a Module definition that covers only part of a previous module (think of a feature extractor made from a classification net), or when you want to expand a network and preload some of its weights from an existing module. In both cases, you need to be very careful to get the naming right, or manually rename the keys in the dictionary.
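A sketch of the feature-extractor case with manual key renaming, assuming a hypothetical `Classifier` that keeps its convolutional part under a `features` attribute. The extractor is a bare nn.Sequential, so its keys lack the `features.` prefix; stripping that prefix with a dict comprehension is just one way to do the renaming.

```python
import torch
import torch.nn as nn


class Classifier(nn.Module):
    """Hypothetical classification net: `features` followed by a linear head."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
        self.classifier = nn.Linear(8, 10)

    def forward(self, x):
        return self.classifier(self.features(x).mean(dim=(2, 3)))


checkpoint = Classifier().state_dict()  # keys like 'features.0.weight'

# Covers only the `features` part; its keys are '0.weight', '0.bias'.
extractor = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())

prefix = "features."
renamed = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in checkpoint.items()
}

# The leftover classifier keys are ignored thanks to strict=False.
result = extractor.load_state_dict(renamed, strict=False)
print(result.missing_keys)     # []
print(result.unexpected_keys)  # ['classifier.weight', 'classifier.bias']
```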


EDIT: Never mind, I found my answer in the official tutorials:
https://pytorch.org/tutorials/beginner/saving_loading_models.html#id4

Whether you are loading from a partial state_dict, which is missing some keys, or loading a state_dict with more keys than the model that you are loading into, you can set the strict argument to False in the load_state_dict() function to ignore non-matching keys.

Just to be sure, will it work with strict=False for this case:

model = ['x', 'y', 'z', 'a']
state_dict = ['x', 'y', 'z', 'b']

I don’t care about ‘a’ and ‘b’; I just want to load x, y and z.


Yes, it will just ignore these if strict=False!
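A minimal sketch of exactly this x/y/z/a vs. x/y/z/b case, assuming each name is a small Linear layer inside an nn.ModuleDict (`make` is a hypothetical helper). x, y and z must also have matching shapes for this to work; `a` simply keeps its initial weights.

```python
import torch
import torch.nn as nn


def make(names):
    """Hypothetical helper: one small Linear layer per name."""
    return nn.ModuleDict({name: nn.Linear(2, 2) for name in names})


model = make(["x", "y", "z", "a"])
source = make(["x", "y", "z", "b"])

result = model.load_state_dict(source.state_dict(), strict=False)
print(result.missing_keys)     # ['a.weight', 'a.bias']
print(result.unexpected_keys)  # ['b.weight', 'b.bias']
```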
