How to load part of a pre-trained model?

(Albert Xavier) #1

How to load part of a pre-trained model?
Will the unused parameters be automatically deleted if I load the whole model but only use part of it?

15 Likes
(Yun Chen) #2

I’m afraid not. From the docs:

The keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

If a name is not in the new model it will raise a KeyError, but I guess this may work for you:

    from torch.nn import Parameter

    # add this as a method on your nn.Module subclass
    def load_my_state_dict(self, state_dict):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            # skip parameters the current model doesn't have
            if name not in own_state:
                continue
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            # copy the pretrained values into the matching parameter
            own_state[name].copy_(param)
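
A minimal usage sketch, assuming the method above is added to your module class (MyNet and pretrained.pth are hypothetical names):

    import torch

    net = MyNet()
    pretrained_state = torch.load("pretrained.pth")  # a previously saved state_dict
    net.load_my_state_dict(pretrained_state)         # only the matching keys are copied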
9 Likes
(Adam Paszke) #3

You can remove all keys that don’t match your model from the state dict and use it to load the weights afterwards:

pretrained_dict = ...
model_dict = model.state_dict()

# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict) 
# 3. load the new state dict
model.load_state_dict(pretrained_dict)
82 Likes
(Albert Xavier) #5

Wow, thanks a lot!
But “model.load(pretrained_dict)” gives me an error “object has no attribute ‘load’”

(Albert Xavier) #6

I think #3 should be:

model.load_state_dict(model_dict)
9 Likes
(Adam Paszke) #7

Yes it should. I edited the snippet.

1 Like
(Chamsu) #8

Thanks, this is just what I was looking for. It seems awesome!

(Prasanna S) #9

Consider the situation where I would like to restore all weights up to the last layer.

# args has the model name, num classes and other irrelevant stuff
# (assumes: from torchvision import models; from torch.utils import model_zoo)
self._model = models.__dict__[args.arch](pretrained=False,
                                         num_classes=args.classes,
                                         aux_logits=False)

if self.args.pretrained:
    print("=> using pre-trained model '{}'".format(args.arch))
    pretrained_state = model_zoo.load_url(model_names[args.arch])
    model_state = self._model.state_dict()

    # keep only the entries whose name and size both match the current model
    pretrained_state = {k: v for k, v in pretrained_state.items()
                        if k in model_state and v.size() == model_state[k].size()}
    model_state.update(pretrained_state)
    self._model.load_state_dict(model_state)

Shouldn’t we also be checking if the sizes match before restoring? It looks like we are comparing only the names.

3 Likes
#10

If the sizes are wrong, I believe the copy_ invoked in load_state_dict will complain.

3 Likes
#13

Step 3 should look like this:

# 3. load the new state dict
model.load_state_dict(model_dict)

4 Likes
#14

@chenyuntc Hello, what if net2 is a subset of net1 and I want to load the weights from net2 into net1? Does directly using load_state_dict work? Thanks!

(Karttikeya Mangalam) #15

As of 21st December ’17, load_state_dict() takes a boolean argument ‘strict’ which, when set to False, loads only the variables that the two models have in common, irrespective of whether one is a subset or superset of the other.
http://pytorch.org/docs/master/_modules/torch/nn/modules/module.html#Module.load_state_dict
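
For example (a minimal sketch; pretrained_dict is assumed to be a state_dict loaded from a checkpoint):

# keys missing from either model are skipped instead of raising an error
model.load_state_dict(pretrained_dict, strict=False)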

9 Likes
(KAI ZHAO) #16

After model_dict.update(pretrained_dict), model_dict may still have keys that pretrained_dict doesn’t have, which will cause an error.

Assume the following situation:

pretrained_dict: ['A', 'B', 'C', 'D']
model_dict: ['A', 'B', 'C', 'E']

After pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} and model_dict.update(pretrained_dict), they are:

pretrained_dict: ['A', 'B', 'C']
model_dict: ['A', 'B', 'C', 'E']

So when performing model.load_state_dict(pretrained_dict), model_dict still has key E that pretrained_dict doesn’t have.

So how about using model.load_state_dict(model_dict) instead of model.load_state_dict(pretrained_dict)?

The complete snippet is therefore as follows:

pretrained_dict = ...
model_dict = model.state_dict()

# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict) 
# 3. load the new state dict
model.load_state_dict(model_dict)
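
As a hypothetical follow-up check, you can list the keys that were not in the checkpoint and therefore keep their original initialization (e.g. 'E' above):

not_pretrained = [k for k in model_dict if k not in pretrained_dict]
print(not_pretrained)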
23 Likes
(romina) #17

Hi, how can I load a trained model with its weights?

#18

Have a look at the Serialization tutorial or do you have a specific issue?
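
In short, a minimal sketch of the usual approach (model.pth and MyModel are placeholder names):

import torch

# save only the weights
torch.save(model.state_dict(), "model.pth")

# later: rebuild the architecture and load the weights back in
model = MyModel()
model.load_state_dict(torch.load("model.pth"))
model.eval()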

(Sriharsha Annamaneni) #19

Is it param = Parameter.data ?

#20

It might be a little late to ask. If I want to skip some layers, for example if I trained the model with batch normalization but now want to use those trained weights in a version without batch normalization, how can I change the layers’ names? Otherwise the names may differ and it will complain about size mismatches.

(Yufan Xue) #21

This is very useful, thanks

(Nurlan) #24

Shouldn’t the # 3 be

model.load_state_dict(model_dict)

instead of

model.load_state_dict(pretrained_dict)

?

(Cipher) #25

If those params are stored in a dictionary, there are a great many parameters that share the same name, e.g. nn.Conv2d, really a lot. So how does it know which nn.Conv2d is the right one to choose when using this dictionary for the mapping? Thanks!