How to load saved weights into a modified net

I trained a net and saved its weights. Then I modified some details of the net, so the saved weights can no longer be loaded with ‘torch.load’. How can I load the weights into the modified net? Please help, thank you.

I use the following code to load a pretrained model. Maybe it will work for you:
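import torch
from torchvision.models import alexnet  # assumes torchvision's AlexNet

model = alexnet()
model_dict = model.state_dict()
pretrained_dict = torch.load("xxx.pkl")
# keep only entries whose name exists in the new model and whose shape matches
pretrained_dict = {k: v for k, v in pretrained_dict.items()
                   if k in model_dict and v.size() == model_dict[k].size()}
# overwrite the matching entries; the remaining layers keep their fresh initialization
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)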
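Below is a self-contained sketch of the same filtering idea that runs without torchvision or a file on disk. OldNet, NewNet and load_matching_weights are hypothetical names made up for illustration, and an in-memory buffer stands in for "xxx.pkl". The "modification" here is assumed to be a resized output layer plus one extra layer.

```python
import io

import torch
import torch.nn as nn


# Hypothetical "original" network whose weights were saved.
class OldNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(8, 16)
        self.classifier = nn.Linear(16, 10)

    def forward(self, x):
        return self.classifier(torch.relu(self.features(x)))


# Hypothetical "modified" network: same feature layer, a different
# output size, and one extra layer that did not exist before.
class NewNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(8, 16)
        self.extra = nn.Linear(16, 16)
        self.classifier = nn.Linear(16, 5)

    def forward(self, x):
        h = torch.relu(self.features(x))
        return self.classifier(torch.relu(self.extra(h)))


def load_matching_weights(model, pretrained_dict):
    """Copy only entries whose name and shape match `model`; return the copied keys."""
    model_dict = model.state_dict()
    matched = {k: v for k, v in pretrained_dict.items()
               if k in model_dict and v.size() == model_dict[k].size()}
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    return sorted(matched)


# Save the old weights to an in-memory buffer (stands in for "xxx.pkl").
old = OldNet()
buffer = io.BytesIO()
torch.save(old.state_dict(), buffer)
buffer.seek(0)

new = NewNet()
loaded = load_matching_weights(new, torch.load(buffer))
print(loaded)  # ['features.bias', 'features.weight']
```

Only the feature layer is copied: the classifier weights have a different shape and the extra layer has no saved counterpart, so both keep their random initialization. Passing strict=False to load_state_dict would tolerate the missing "extra" keys, but a shape mismatch like the resized classifier still raises a RuntimeError, which is why the explicit size check is useful.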


Thank you very much for your help.