Load part of pretrained model with strict=False in load_state_dict

Hi,

I have a pretrained model and I want to load only part of it into my model, because I'm replacing the classification layer. What I'm doing is:

model.load_state_dict(state_dict, strict=False)

I just need a heads-up on whether this is the correct way to go. Thank you very much.


That sounds about right. With strict=False, load_state_dict loads every parameter whose key appears in both the pretrained state_dict and your model, and ignores missing or unexpected keys instead of raising an error.
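A minimal sketch of that behaviour, assuming two hypothetical toy models (Pretrained and NewModel) that share a features layer but use differently named heads. The object returned by load_state_dict lists the keys that were skipped:

```python
import torch
import torch.nn as nn

# Hypothetical architectures for illustration: the pretrained net has a
# 10-class head called "classifier"; the new net has a 3-class head called "fc".
class Pretrained(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(8, 4)
        self.classifier = nn.Linear(4, 10)

class NewModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(8, 4)  # same name and shape, so it gets loaded
        self.fc = nn.Linear(4, 3)        # new head, keeps its fresh initialization

pretrained = Pretrained()
model = NewModel()

# strict=False copies every key present in both dicts and reports the rest
result = model.load_state_dict(pretrained.state_dict(), strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)
```

Checking missing_keys and unexpected_keys after loading is a cheap way to confirm that only the replaced head was left out.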

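It won't do any magic, though: tensors stored under the same key must still have the same shape, otherwise load_state_dict raises a RuntimeError even with strict=False.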
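A small sketch of the shape check, assuming two hypothetical nn.Sequential models whose last layer has the same key but a different number of outputs. strict=False does not suppress the size-mismatch error:

```python
import torch
import torch.nn as nn

# Hypothetical models: identical backbone, but the head under key "1.*"
# has 10 outputs in the pretrained model and 3 in the new one.
old = nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 10))
new = nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 3))

raised = False
try:
    new.load_state_dict(old.state_dict(), strict=False)
except RuntimeError as err:
    # The message reports a size mismatch for 1.weight and 1.bias
    raised = True
    print(err)

print("raised:", raised)
```

So when the replaced classifier keeps the same attribute name as in the pretrained model, its entries have to be removed from the state_dict (or renamed) before loading.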


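You can also try this approach: take a pretrained tensor only when its shape matches the current model's, and otherwise keep the model's own tensor.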

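import torch
current_model = net.state_dict()
keys_pre = torch.load('', map_location=device)

# Take the pretrained tensor when shapes match, else keep the model's own tensor.
# zip pairs entries by position, so both state_dicts must list layers in the same order.
new_state_dict = {k: v if v.size() == current_model[k].size() else current_model[k]
                  for k, v in zip(current_model.keys(), keys_pre['model_state_dict'].values())}
net.load_state_dict(new_state_dict)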
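Because the snippet above pairs tensors with zip, it depends on both state_dicts having the same number of entries in the same order. A sketch that matches by key name instead, using a hypothetical helper filter_matching and two toy models, could look like this:

```python
import torch
import torch.nn as nn

def filter_matching(pretrained_sd, model_sd):
    """Keep pretrained tensors whose key exists in the model with the same shape."""
    return {k: v for k, v in pretrained_sd.items()
            if k in model_sd and v.shape == model_sd[k].shape}

# Hypothetical models: same backbone, different head size under key "2.*"
old = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 10))
new = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 3))

filtered = filter_matching(old.state_dict(), new.state_dict())
# strict=False because the head's keys were filtered out on purpose
result = new.load_state_dict(filtered, strict=False)
print(sorted(filtered))
print(result.missing_keys)
```

Keys that are filtered out show up in missing_keys and keep their fresh initialization, which is usually what you want for a replaced classifier.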