Why do we need state_dict = state_dict.copy()

Hello everyone,

I haven’t used PyTorch in a long time, and there is one line of code that confuses me. I would appreciate it if anyone could help explain it.

In my task I want to load the weights from an already pretrained model into my local model. But I don’t understand why we need the line `state_dict = state_dict.copy()` when the two variables have the same name, `state_dict`.

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # accumulators for load results (defined earlier in the original code)
    missing_keys, unexpected_keys, error_msgs = [], [], []

    def load(module, prefix=''):
        # metadata for this module, keyed by its prefix without the trailing '.'
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True,
            missing_keys, unexpected_keys, error_msgs)
        # recurse into child modules, extending the key prefix
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    start_prefix = ''
    # print(hasattr(model, 'bert'))  # -> False in my case
    if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
        start_prefix = 'bert.'
    load(model, prefix=start_prefix)

This is just like how Hugging Face uses it in the following link (line 669).
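
To make my question concrete, here is a small experiment I tried (my own sketch with a throwaway nn.Linear, not code from Hugging Face). As far as I can tell, state_dict() returns an OrderedDict that also carries a _metadata attribute, and .copy() gives a new dict container but drops that attribute:

    import torch
    import torch.nn as nn

    model = nn.Linear(2, 2)
    state_dict = model.state_dict()          # an OrderedDict with a _metadata attribute
    print(hasattr(state_dict, '_metadata'))  # True

    # Rebinding a name to a shallow copy gives a new dict object:
    # popping keys from the copy does not touch the original dict.
    state_dict_copy = state_dict.copy()
    state_dict_copy.pop('weight')
    print('weight' in state_dict)       # True: the original is untouched
    print('weight' in state_dict_copy)  # False: only the copy was mutated

    # .copy() copies the dict entries but not instance attributes,
    # which seems to be why _metadata is re-attached by hand above.
    print(hasattr(state_dict_copy, '_metadata'))  # False

So my guess is that the copy lets the loading code modify the dict without mutating the caller's dict, but I would appreciate confirmation.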


I am also confused by this line. Have you figured it out? I found nothing after a whole day of googling.

This might be helpful: Cache = self.state_dict() overwritten - #6 by ptrblck
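
In short, as I understand the linked answer: state_dict() returns tensors that still share storage with the model's parameters, and dict.copy() is shallow (it copies the dict container, not the tensors inside), so neither gives you an independent snapshot of the values. A minimal sketch of both effects (a throwaway nn.Linear, my own example):

    import copy
    import torch
    import torch.nn as nn

    model = nn.Linear(2, 2)

    cached = model.state_dict()       # tensors here alias the parameters
    snapshot = copy.deepcopy(cached)  # a truly independent snapshot

    # Simulate a training step by updating the parameters in place.
    with torch.no_grad():
        model.weight.add_(1.0)

    # The cached dict was "overwritten" because its tensors share storage
    # with the parameters; the deepcopy kept the old values.
    print(torch.equal(cached['weight'], model.weight))    # True: aliased
    print(torch.equal(snapshot['weight'], model.weight))  # False: independent

    # dict.copy(), as used in the loading code above, is shallow: it copies
    # the dict container only, so the tensors inside are the same objects.
    shallow = cached.copy()
    print(shallow['weight'] is cached['weight'])  # True

That is also why the Hugging Face snippet can safely rebind state_dict to a copy: the loading code only needs a private dict container to work on, not new tensors.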