I saved my pre-trained model and load it when I need to re-train, but sometimes I modify the net's structure. I want the program to check automatically whether the saved parameters still fit and, if they don't, train the net from scratch. Is there any method that can do this?
I did it this way:
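```python
# Adapted from https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/3
import torch


def load_valid(model, pretrained_file, skip_list=None):
    skip_list = skip_list or []  # treat a missing skip_list as "skip nothing"
    pretrained_dict = torch.load(pretrained_file)
    model_dict = model.state_dict()
    # 1. filter out unnecessary keys: keep only parameters the model still has
    #    and that are not explicitly skipped
    pretrained_dict1 = {k: v for k, v in pretrained_dict.items()
                        if k in model_dict and k not in skip_list}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict1)
    # 3. load the merged dict; filtered-out keys keep their fresh initialisation
    model.load_state_dict(model_dict)


# Example usage
pretrained_file = '/root/share/project/pytorch/data/pretrain/inception/inception_v3_google-1a9a5a14.pth'
net = Inception3(num_classes=10)  # Inception3: the network class, defined or imported elsewhere
if pretrained_file is not None:  # initialise from the pretrained weights
    # the classifier's shape depends on num_classes, so don't copy it
    skip_list = ['fc.weight', 'fc.bias']
    load_valid(net, pretrained_file, skip_list=skip_list)
```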
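Filtering by key alone still requires listing the layers whose size changed (here `fc`) by hand. Otherwise `load_state_dict` raises a size-mismatch error. The following is a minimal sketch of an automatic check, assuming that "fits" means the key exists in the new model and the tensor shape is unchanged. The helper `load_matching_params` is a hypothetical name. It copies only the matching tensors and reports the rest. The toy `nn.Sequential` models stand in for the real network.

```python
import torch
import torch.nn as nn


def load_matching_params(model, pretrained_state):
    """Copy every pretrained tensor whose key AND shape match the model.

    Hypothetical helper (a sketch, not a PyTorch API). Returns
    (loaded_keys, skipped_keys) so the caller can see what was restored.
    """
    model_state = model.state_dict()
    loaded, skipped = [], []
    for k, v in pretrained_state.items():
        if k in model_state and model_state[k].shape == v.shape:
            model_state[k] = v  # compatible: take the pretrained tensor
            loaded.append(k)
        else:
            skipped.append(k)   # renamed, removed, or resized: ignore it
    # model_state still contains every key of the model, so a strict load is safe;
    # entries that were not replaced keep the model's fresh initialisation.
    model.load_state_dict(model_state)
    return loaded, skipped


# Old net predicts 3 classes; the modified one predicts 5, so only the
# first Linear layer still fits.
old_net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
new_net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 5))

loaded, skipped = load_matching_params(new_net, old_net.state_dict())
print(loaded)   # ['0.weight', '0.bias']
print(skipped)  # ['2.weight', '2.bias']
```

With the real checkpoint, you would pass `torch.load(pretrained_file, map_location='cpu')` as `pretrained_state`. Tensors that do not fit stay at their fresh initialisation, so those layers are trained from scratch. If `loaded` comes back empty, the whole network is trained from scratch.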