Hello everyone,
I haven’t used PyTorch in a long time, and there is one line of code that confuses me very much. If any of you can help explain it, I would appreciate it very much.
In my task I want to load the weights from an already pretrained model into my local model. But I don’t understand why we need the line " state_dict = state_dict.copy() " when the two sides have the same name, " state_dict ".
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
    state_dict._metadata = metadata

def load(module, prefix=''):
    local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
    module._load_from_state_dict(
        state_dict, prefix, local_metadata, True,
        missing_keys, unexpected_keys, error_msgs)
    for name, child in module._modules.items():
        if child is not None:
            load(child, prefix + name + '.')

start_prefix = ''
# note: hasattr(model, 'bert') is False in my case
if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
    start_prefix = 'bert.'
load(model, prefix=start_prefix)
This is just like how Hugging Face uses it in the following link (line 669).
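For context, my current guess is that dict.copy() makes a shallow copy, so that later mutations (such as removing keys) happen only on the copy and leave the caller's dict intact, while the assignment just rebinds the local name state_dict. Here is a toy sketch of that behavior (my own example with made-up names, not the Hugging Face code):

```python
# original is a stand-in for the caller's state dict
original = {'bert.weight': 1, 'bert.bias': 2}

def load_and_consume(sd):
    # Rebind the local name 'sd' to a shallow copy; the caller's dict
    # is untouched by anything we do to 'sd' from here on.
    sd = sd.copy()
    # Simulate _load_from_state_dict consuming an entry.
    sd.pop('bert.weight')
    return sd

consumed = load_and_consume(original)
print(sorted(consumed))   # only 'bert.bias' remains in the copy
print(sorted(original))   # the caller's dict still has both keys
```

Is that the reason for the copy, or is there something more to it?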