How to load part of a pre-trained model?
Will the unused parameters be automatically deleted if I load the whole model but only use part of it?
I'm afraid not:
The keys of state_dict must exactly match the keys returned by this module's state_dict() function.
If a name is not in the new model, it will raise a KeyError, but I guess this may work for you:
from torch.nn import Parameter

def load_my_state_dict(self, state_dict):
    own_state = self.state_dict()
    for name, param in state_dict.items():
        if name not in own_state:
            # skip parameters the new model doesn't have
            continue
        if isinstance(param, Parameter):
            # backwards compatibility for serialized parameters
            param = param.data
        own_state[name].copy_(param)
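For example, you could call the helper on an existing module like this (the resnet18 choice is just an assumption for illustration):
import torchvision.models as models

model = models.resnet18()                                  # target model
pretrained_state = models.resnet18(pretrained=True).state_dict()

# works as a plain function: pass the module as `self`
load_my_state_dict(model, pretrained_state)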
You can remove all keys that don't match your model from the state dict and use it to load the weights afterwards:
pretrained_dict = ...
model_dict = model.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)
Wow, thanks a lot!
But "model.load(pretrained_dict)" gives me an error: "object has no attribute 'load'".
I think #3 should be:
model.load_state_dict(model_dict)
Yes it should. I edited the snippet.
Thanks. This is what I was looking for. It seems awesome!
Consider the situation where I would like to restore all weights up to the last layer.
# args has the model name, num classes, and other irrelevant stuff
self._model = models.__dict__[args.arch](pretrained=False,
                                         num_classes=args.classes,
                                         aux_logits=False)
if self.args.pretrained:
    print("=> using pre-trained model '{}'".format(args.arch))
    pretrained_state = model_zoo.load_url(model_names[args.arch])
    model_state = self._model.state_dict()
    # keep only weights that exist in the new model with matching shapes
    pretrained_state = {k: v for k, v in pretrained_state.items()
                        if k in model_state and v.size() == model_state[k].size()}
    model_state.update(pretrained_state)
    self._model.load_state_dict(model_state)
Shouldn't we also be checking if the sizes match before restoring? It looks like we are comparing only the names.
If the sizes are wrong, I believe the copy_ invoked in load_state_dict will complain.
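To avoid that complaint up front, the filtering step can also compare shapes; a small sketch using the same dict names as the snippet above:
# drop keys whose tensors would not fit the new model
pretrained_dict = {k: v for k, v in pretrained_dict.items()
                   if k in model_dict and v.size() == model_dict[k].size()}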
Step 3 should look like this:
# 3. load the new state dict
model.load_state_dict(model_dict)
@chenyuntc Hello, what if net2 is a subset of net1 and I want to load the weights from net2 into net1? Does directly using load_state_dict work? Thanks!
As of 21st December '17, load_state_dict() takes the boolean argument strict which, when set to False, loads only the parameters whose names match between the two models, irrespective of whether one is a subset/superset of the other.
http://pytorch.org/docs/master/_modules/torch/nn/modules/module.html#Module.load_state_dict
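A minimal sketch of strict=False with a toy subset/superset pair (the module choices here are just an illustration); note that it only skips missing or unexpected keys, while shared keys with mismatched shapes will still raise an error:
import torch.nn as nn

net2 = nn.Sequential(nn.Linear(10, 20))                    # subset
net1 = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 5))  # superset

# keys that exist only in net1 ('1.weight', '1.bias') are skipped
net1.load_state_dict(net2.state_dict(), strict=False)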
After model_dict.update(pretrained_dict), model_dict may still have keys that pretrained_dict doesn't have, so loading pretrained_dict will cause an error.
Assume the following situation:
pretrained_dict: ['A', 'B', 'C', 'D']
model_dict: ['A', 'B', 'C', 'E']
After pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} and model_dict.update(pretrained_dict), they are:
pretrained_dict: ['A', 'B', 'C']
model_dict: ['A', 'B', 'C', 'E']
So when performing model.load_state_dict(pretrained_dict), model_dict still has key E that pretrained_dict doesn't have. So how about using model.load_state_dict(model_dict) instead of model.load_state_dict(pretrained_dict)?
The complete snippet is therefore as follows:
pretrained_dict = ...
model_dict = model.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)
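For concreteness, here is a hypothetical instantiation of the snippet above, assuming the pretrained weights come from torchvision's resnet18 and the new model only differs in its final layer:
import torchvision.models as models

model = models.resnet18(num_classes=10)      # 'fc.*' shapes differ
model_dict = model.state_dict()
pretrained_dict = models.resnet18(pretrained=True).state_dict()

# 1. filter out unnecessary keys (also checking shapes, to be safe)
pretrained_dict = {k: v for k, v in pretrained_dict.items()
                   if k in model_dict and v.size() == model_dict[k].size()}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)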
Hi, how can I load a trained model with its weights?
Have a look at the Serialization tutorial, or do you have a specific issue?
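In case a minimal example helps, the usual pattern is to save and restore the state_dict (the file name here is just a placeholder):
import torch

# save only the weights, not the module object itself
torch.save(model.state_dict(), 'checkpoint.pth')

# later: rebuild the model with the same architecture, then restore
model.load_state_dict(torch.load('checkpoint.pth'))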
Is it param = Parameter.data?
It might be a little late to ask: if I want to skip some layers, e.g. I trained the model with batch normalization but want to use the trained weights in a version without batch normalization, how can I change the layers' names? Otherwise the names might be different, and it will complain about size mismatches.
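One possible approach (a sketch, assuming the only difference is a substring in the parameter names; the 'bn.' and 'norm.' substrings here are purely hypothetical) is to rename the keys before loading:
# remap old key names onto the new model's naming scheme
renamed_dict = {k.replace('bn.', 'norm.'): v
                for k, v in pretrained_dict.items()}
model.load_state_dict(renamed_dict, strict=False)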
This is very useful, thanks
Shouldn't #3 be model.load_state_dict(model_dict) instead of model.load_state_dict(pretrained_dict)?
If those params are stored in a dictionary, there are a great many parameters that have the same name, e.g. nn.Conv2d, really a lot. So how could it know which nn.Conv2d is the right one to choose when using this dictionary for the mapping? Thanks!
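For reference, the keys of a state_dict are fully qualified, hierarchical attribute names rather than module types, so every nn.Conv2d ends up with a unique key; a quick sketch:
import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
print(list(net.state_dict().keys()))
# ['0.weight', '0.bias', '2.weight', '2.bias']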