How to load part of a pre-trained model?

You can remove all keys that don’t match your model from the state dict and use it to load the weights afterwards:

pretrained_dict = ...
model_dict = model.state_dict()

# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict) 
# 3. load the new state dict
model.load_state_dict(model_dict)

Wow, thanks a lot!
But “model.load(pretrained_dict)” gives me an error “object has no attribute ‘load’”


I think #3 should be:

model.load_state_dict(model_dict)

Yes it should. I edited the snippet.


Thanks, I was looking for this. It seems awesome!

Consider the situation where I would like to restore all weights up to the last layer.

# args has the model name, num classes and other irrelevant stuff
self._model = models.__dict__[args.arch](pretrained = False, 
                                         num_classes = args.classes, 
                                         aux_logits = False)


if self.args.pretrained:
    print("=> using pre-trained model '{}'".format(args.arch))
    pretrained_state = model_zoo.load_url(model_names[args.arch])
    model_state = self._model.state_dict()

    # keep only entries whose name and shape both match the current model
    pretrained_state = {k: v for k, v in pretrained_state.items()
                        if k in model_state and v.size() == model_state[k].size()}
    model_state.update(pretrained_state)
    self._model.load_state_dict(model_state)

Shouldn’t we also be checking if the sizes match before restoring? It looks like we are comparing only the names.


If the sizes are wrong, I believe the copy_ invoked in load_state_dict will complain.


Step 3 should look like this:

# 3. load the new state dict
model.load_state_dict(model_dict)


@chenyuntc Hello, what if net2 is a subset of net1, and I want to load the weights from net2 into net1? Does directly calling load_state_dict work? Thanks!

As of 21st December '17, load_state_dict() takes the boolean argument 'strict', which, when set to False, loads only the parameters that are identical between the two models, irrespective of whether one is a subset/superset of the other.
http://pytorch.org/docs/master/_modules/torch/nn/modules/module.html#Module.load_state_dict
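
For example, a minimal sketch of that usage, assuming net1 is the larger model from your question (the checkpoint path here is just a placeholder):

import torch

# state dict from the smaller net2, e.g. saved earlier (placeholder path)
pretrained_dict = torch.load("net2_checkpoint.pth")

# strict=False ignores keys that are missing on either side, so only the
# parameters whose names exist in both models are copied; keys that match
# by name but differ in shape will typically still raise an error
net1.load_state_dict(pretrained_dict, strict=False)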


After model_dict.update(pretrained_dict), model_dict may still have keys that pretrained_dict doesn't have, which will cause an error.

Assume the following situation:

pretrained_dict: ['A', 'B', 'C', 'D']
model_dict: ['A', 'B', 'C', 'E']

After pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} and model_dict.update(pretrained_dict), they are:

pretrained_dict: ['A', 'B', 'C']
model_dict: ['A', 'B', 'C', 'E']

So when performing model.load_state_dict(pretrained_dict), model_dict still has the key 'E' that pretrained_dict doesn't have.

So how about using model.load_state_dict(model_dict) instead of model.load_state_dict(pretrained_dict)?

The complete snippet is therefore as follows:

pretrained_dict = ...
model_dict = model.state_dict()

# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict) 
# 3. load the new state dict
model.load_state_dict(model_dict)
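
As a side note, the pretrained_dict in the first line usually comes either from a checkpoint on disk or from a torchvision model; a rough sketch (the file name below is just a placeholder):

import torch
import torchvision.models as models

# option 1: weights saved earlier with torch.save(model.state_dict(), ...)
pretrained_dict = torch.load("checkpoint.pth")

# option 2: a torchvision model shipped with pretrained ImageNet weights
pretrained_dict = models.resnet18(pretrained=True).state_dict()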

Hi, how can I load a trained model with its weights?

Have a look at the Serialization tutorial or do you have a specific issue?
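
The usual pattern from that tutorial is roughly the following (the path and MyModel are placeholders):

import torch

# save only the weights
torch.save(model.state_dict(), "model_weights.pth")

# later: rebuild the same architecture and load the weights back
model = MyModel()  # hypothetical model class
model.load_state_dict(torch.load("model_weights.pth"))
model.eval()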


Is it param = Parameter.data ?

It might be a little late to ask. If I want to skip some layers, for example if I trained the model with batch normalization but want to use those trained weights in a version without batch normalization, how can I change the layers' names? Otherwise the names might be different, and it will complain about a size mismatch.

This is very useful, thanks

Shouldn’t the # 3 be

model.load_state_dict(model_dict)

instead of

model.load_state_dict(pretrained_dict)

?


If those params are stored in a dictionary, there are a great many parameters with the same module type, e.g. nn.Conv2d, really a lot.
So how does it know which nn.Conv2d is the right one to choose when using this dictionary for the mapping?
Thx~

@Zichun_Zhang see model.named_parameters(); it returns each key name and its corresponding parameter handle. Would that help?
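
For instance, a quick way to inspect those keys on any nn.Module:

# print each parameter name together with its shape; these names are the
# same keys that appear in model.state_dict()
for name, param in model.named_parameters():
    print(name, tuple(param.shape))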

<bound method Module.named_parameters of SiamRPN(
  (featureExtract): Sequential(
    (0): Conv2d(3, 192, kernel_size=(11, 11), stride=(2, 2))
    (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): ReLU(inplace)
    (4): Conv2d(192, 512, kernel_size=(5, 5), stride=(1, 1))
    (5): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): ReLU(inplace)
    (8): Conv2d(512, 768, kernel_size=(3, 3), stride=(1, 1))
    (9): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace)
    (11): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1))
    (12): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace)
    (14): Conv2d(768, 512, kernel_size=(3, 3), stride=(1, 1))
    (15): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv_r1): Conv2d(512, 10240, kernel_size=(3, 3), stride=(1, 1))
  (conv_r2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
  (conv_cls1): Conv2d(512, 5120, kernel_size=(3, 3), stride=(1, 1))
  (conv_cls2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
  (regress_adjust): Conv2d(20, 20, kernel_size=(1, 1), stride=(1, 1))
)>

I just tried your method on my net class, and it returned this. So names like (conv_r1) are the keys, and they correspond to the names I used when creating the modules, right? These are the keys?

BTW, thx a lot smth