How to load part of a pre-trained model?

How can I load only part of a pre-trained model? And if I load the whole model but only use part of it, will the unused parameters be deleted automatically?

38 Likes

I’m afraid not.

The keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

If a name is not in the new model, it will raise a KeyError, but I guess something like this may work for you:

    from torch.nn import Parameter

    def load_my_state_dict(self, state_dict):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                # skip parameters the new model does not have
                continue
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            own_state[name].copy_(param)
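
If it helps, a minimal usage sketch: attach this as a method of your own nn.Module subclass and call it with a loaded checkpoint (MyNet and the file name here are hypothetical):

    net = MyNet()  # your nn.Module subclass that defines load_my_state_dict above
    net.load_my_state_dict(torch.load('pretrained.pth'))  # hypothetical checkpoint path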
31 Likes

You can remove all keys that don’t match your model from the state dict and use it to load the weights afterwards:

pretrained_dict = ...
model_dict = model.state_dict()

# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict) 
# 3. load the new state dict
model.load_state_dict(model_dict)
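
For completeness, a minimal end-to-end sketch of the same pattern, using torchvision's ResNet-18 as a hypothetical source of weights (the 10-class head is arbitrary). A shape check is added to the filter, since identically named layers can still differ in size:

    import torch
    import torchvision.models as models

    # source of weights: a stock pretrained ResNet-18
    pretrained_dict = models.resnet18(pretrained=True).state_dict()

    # target: the same architecture but with a 10-class head, so the 'fc.*' shapes differ
    model = models.resnet18(num_classes=10)
    model_dict = model.state_dict()

    # 1. keep only entries whose name *and* shape match the target model
    pretrained_dict = {k: v for k, v in pretrained_dict.items()
                       if k in model_dict and v.size() == model_dict[k].size()}
    # 2. overwrite matching entries and 3. load the merged state dict
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)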
172 Likes

Wow, thanks a lot!
But “model.load(pretrained_dict)” gives me the error “object has no attribute ‘load’”.

1 Like

I think #3 should be:

model.load_state_dict(model_dict)
17 Likes

Yes it should. I edited the snippet.

2 Likes

Thanks, this is exactly what I was looking for. It seems awesome!

Consider the situation where I would like to restore all weights up to the last layer.

# args holds the model name, number of classes, and other settings
self._model = models.__dict__[args.arch](pretrained=False,
                                         num_classes=args.classes,
                                         aux_logits=False)

if self.args.pretrained:
    print("=> using pre-trained model '{}'".format(args.arch))
    pretrained_state = model_zoo.load_url(model_names[args.arch])
    model_state = self._model.state_dict()

    # keep only weights whose name and shape both match the new model
    pretrained_state = {k: v for k, v in pretrained_state.items()
                        if k in model_state and v.size() == model_state[k].size()}
    model_state.update(pretrained_state)
    self._model.load_state_dict(model_state)

Shouldn’t we also be checking whether the sizes match before restoring? The earlier snippet seems to compare only the names.

4 Likes

If the sizes are wrong, I believe the copy_ invoked in load_state_dict will complain.
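
A tiny sketch of that behavior (shapes chosen arbitrarily): copying between incompatible shapes raises a RuntimeError:

    import torch

    a = torch.zeros(3, 4)
    b = torch.zeros(5, 6)
    a.copy_(b)  # raises RuntimeError: the shapes are incompatible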

7 Likes

Step 3 should look like this:

    # 3. load the new state dict
    model.load_state_dict(model_dict)

5 Likes

@chenyuntc Hello, what if net2 is a subset of net1 and I want to load the weights from net2 into net1? Does directly using load_state_dict work? Thanks!

As of 21st December ’17, load_state_dict() takes the boolean argument strict, which, when set to False, loads only the parameters whose names match between the two models, regardless of whether one is a subset or superset of the other.
http://pytorch.org/docs/master/_modules/torch/nn/modules/module.html#Module.load_state_dict
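
For example, a minimal sketch (assuming pretrained_dict is an already loaded state dict):

    # ignore keys that are missing from, or extra to, the model;
    # shapes of the keys that do match still have to agree
    model.load_state_dict(pretrained_dict, strict=False)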

16 Likes

After model_dict.update(pretrained_dict), the model_dict may still have keys that pretrained_dict doesn’t have, which will cause an error.

Assume the following situation:

pretrained_dict: ['A', 'B', 'C', 'D']
model_dict: ['A', 'B', 'C', 'E']

After pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} and model_dict.update(pretrained_dict), they are:

pretrained_dict: ['A', 'B', 'C']
model_dict: ['A', 'B', 'C', 'E']

So when performing model.load_state_dict(pretrained_dict), model_dict still has the key E that pretrained_dict doesn’t have.

So how about using model.load_state_dict(model_dict) instead of model.load_state_dict(pretrained_dict)?

The complete snippet is therefore as follows:

pretrained_dict = ...
model_dict = model.state_dict()

# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict) 
# 3. load the new state dict
model.load_state_dict(model_dict)
54 Likes

Hi, how can I load a trained model with its weights?

Have a look at the Serialization tutorial or do you have a specific issue?
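
A minimal sketch of the usual approach (the class name and file name are hypothetical):

    import torch

    # save only the state dict
    torch.save(model.state_dict(), 'checkpoint.pth')

    # restore: rebuild the model, then load the weights
    model = MyModel()  # hypothetical model class
    model.load_state_dict(torch.load('checkpoint.pth'))
    model.eval()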

1 Like

Is it param = Parameter.data?

It might be a little late to ask. If I want to skip some layers, for example if I trained the model with batch normalization but want to use those trained weights in a version without batch normalization, how can I change the layers’ names? Otherwise the names won’t line up, and it will complain about mismatches.
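
One possible approach, sketched under the assumption that only the key prefixes differ between the two architectures (the rename rule here is hypothetical):

    # build a renamed copy of the checkpoint before loading it
    renamed_dict = {k.replace('features.conv1', 'conv1'): v  # hypothetical rename rule
                    for k, v in pretrained_dict.items()}
    model.load_state_dict(renamed_dict, strict=False)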

This is very useful, thanks

Shouldn’t the # 3 be

model.load_state_dict(model_dict)

instead of

model.load_state_dict(pretrained_dict)

?

1 Like

If the parameters are stored in a dictionary, there are a great many parameters with the same layer type, e.g. nn.Conv2d, really a lot. So how does it know which nn.Conv2d is the right one to choose when mapping from this dictionary? Thanks!
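
For what it’s worth, state_dict keys are fully qualified by each submodule’s name in the module hierarchy, so they are unique even when many layers share the same type. A small sketch:

    import torch.nn as nn

    model = nn.Sequential(
        nn.Conv2d(3, 16, 3),
        nn.ReLU(),
        nn.Conv2d(16, 32, 3),
    )
    print(list(model.state_dict().keys()))
    # ['0.weight', '0.bias', '2.weight', '2.bias'] -- one entry per submodule, no clashes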