I have trained and saved a model (Model A) to a file. Then, I create a new model object Model B whose structure is different but similar with Model A. I want to load the selected parameters from Model A to the selected parameters of Model B. For example, to a fully connected layer, I just want to load parameters with value>0.5 of weight part but do not need bias part.
I read the code of Module class of pytorch and it use load_state_dict to load parameters. The main part of code is:
def load_state_dict(self, state_dict, strict=True):
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(self)
It use _load_from_state_dict to load parameters, and main part of _load_from_state_dict is:
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
local_name_params = itertools.chain(self._parameters.items(), self._buffers.items())
local_state = {k: v.data for k, v in local_name_params if v is not None}
for name, param in local_state.items():
key = prefix + name
if key in state_dict:
input_param = state_dict[key]
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
if len(param.shape) == 0 and len(input_param.shape) == 1:
input_param = input_param[0]
if input_param.shape != param.shape:
# local shape should match the one in checkpoint
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
'the shape in current model is {}.'
.format(key, input_param.shape, param.shape))
continue
if isinstance(input_param, Parameter):
# backwards compatibility for serialized parameters
input_param = input_param.data
try:
param.copy_(input_param)
except Exception:
error_msgs.append('While copying the parameter named "{}", '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}.'
.format(key, param.size(), input_param.size()))
elif strict:
missing_keys.append(key)
It seems to copy the data of weight and bias in each layer from save file to new model so that I write a function to do same things. For convenience, I just use same model to save and load.
def load_para(model, ckpt_file, load_to_cpu=True):
map_location = (lambda storage, loc: storage) if load_to_cpu else None
ckpt = torch.load(ckpt_file, map_location=map_location) # The model and optimizer is loaded from the saved file
para_dict = ckpt['state_dicts'][0] # Get the state dict
for n,p in model.named_parameters():
ip = para_dict[n] # Get the parameters with selected layer name from save file
if p.shape == ip.shape:
p.data.copy_(ip.data) # Copy the data of parameters
else:
print('{} -shape {} ,{}'.format(n, (p.shape), (ip.shape)))
print(judge_equal_para(model, ckpt_file, load_to_cpu)) # Judge whether the parameter of saved file and model after loading is equal
I also use the method from the office tutorials:
def load_para(model, ckpt_file, load_to_cpu=True):
map_location = (lambda storage, loc: storage) if load_to_cpu else None
ckpt = torch.load(ckpt_file, map_location=map_location)
para_dict = ckpt['state_dicts'][0]
model.load_state_dict(para_dict)
print(judge_equal_para(model, ckpt_file, load_to_cpu))
The function to judge whether the parameter of saved file and model after loading is equal:
def judge_equal_para(model, ckpt_file, load_to_cpu=True):
map_location = (lambda storage, loc: storage) if load_to_cpu else None
ckpt = torch.load(ckpt_file, map_location=map_location)
para_dict = ckpt['state_dicts'][0].copy()
judge = True
for n,p in model.named_parameters():
ip = para_dict[n]
if p.shape == ip.shape:
if (p.data!=ip.data).all(): # judge whether the parameter of saved file and model after loading is equal
return False
else:
pass
return judge
The result from method of the office tutorials is:
True # he parameter of saved file and model after loading is equal
epoch[1]: r@1=0.949940, r@5=0.973778, r@10=0.985101 # Larger will be better
But the result from my method is:
True #he parameter of saved file and model after loading is equal
epoch[1]: r@1=0.008343, r@5=0.029797, r@10=0.045888
I do not know where is wrong and I hope you can give me a hand.