Hi, I have Model A, which is pretrained, and Model B, which is new. Model B has all of Model A's layers plus one extra layer.
I want to load the weights from Model A into Model B for the layers they have in common.
I implemented the following:
pretrained_path = torch.load(path_to_pretrained_model)
new_model_dict = model.state_dict()
pretrained_weights = {k: v for k, v in pretrained_path.items() if k in new_model_dict}
new_model_dict.update(pretrained_weights)
model.load_state_dict(pretrained_weights, strict=False)
I had to use strict=False because the extra layer in Model B was raising a missing-keys error.
But this has made my results pretty bad.
How do I fix this? I want to avoid strict=False.
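For comparison, the usual partial-loading recipe passes the *merged* dict (the one you built with `update`) to `load_state_dict`, not the filtered `pretrained_weights` dict; since the merged dict contains every key Model B expects, the default strict=True then works. A minimal sketch with hypothetical toy modules standing in for the real models:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the real Model A and Model B.
class ModelA(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)

class ModelB(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)  # shared with Model A
        self.fc2 = nn.Linear(4, 2)  # extra layer, keeps its fresh init

model_a = ModelA()
model = ModelB()

pretrained_dict = model_a.state_dict()  # in practice: torch.load(checkpoint_path)
new_model_dict = model.state_dict()

# Keep only the keys both models share.
shared = {k: v for k, v in pretrained_dict.items() if k in new_model_dict}
new_model_dict.update(shared)

# Load the merged dict: it has all of Model B's keys, so strict=True holds.
model.load_state_dict(new_model_dict)
```

After this, `model.fc1` carries Model A's weights while `model.fc2` keeps its random initialization, and no strict=False is needed.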