Modifying nn.Module._load_from_state_dict

Hi, I’m trying to modify nn.Module._load_from_state_dict() (torch.nn.modules.module — PyTorch 1.7.1 documentation).

My class inherits from nn.Module:

import itertools

import torch.nn as nn


class foo(nn.Module):

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):

        # The following lines are copied from the first few lines of the original _load_from_state_dict()
        for hook in self._load_state_dict_pre_hooks.values():
            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        # Only parameters and buffers registered directly on this module (not on its submodules)
        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
        local_state = {k: v for k, v in local_name_params if v is not None}

        # ... the rest of the original method would continue here

However, it doesn't seem to be that simple. When I stepped into the method with pdb.set_trace(), self._load_state_dict_pre_hooks and persistent_buffers were empty dictionaries, and local_name_params yielded nothing.

Can someone explain why this happens and how I can modify _load_from_state_dict()?
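A likely explanation, assuming foo keeps its weights in submodules (its __init__ is not shown in the post): load_state_dict() walks the module tree and calls _load_from_state_dict() once per module, each time with that module's key prefix, so each call only sees the parameters and buffers registered directly on that module. _load_state_dict_pre_hooks is likewise empty unless a hook has been registered. The hypothetical Foo below holds an nn.Linear child and records what its own override sees:

```python
import torch.nn as nn


class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        # The weights live in a child module, not on Foo itself.
        self.fc = nn.Linear(100, 10)
        self.seen_params = None

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Record which parameters are registered directly on this module.
        self.seen_params = list(self._parameters.keys())
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


model = Foo()
model.load_state_dict(model.state_dict())
print(model.seen_params)         # []
print(list(model.state_dict()))  # ['fc.weight', 'fc.bias']
```

The fc.weight and fc.bias entries are handled by the child nn.Linear's own _load_from_state_dict() call (with prefix "fc."), not by Foo's.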


By convention, Python methods whose names start with an underscore (such as _foo) are considered private and are not part of the user-facing API (note that this one is not documented).
Why do you want to modify that function?

Currently, if the dimensions of a tensor don't match, _load_from_state_dict() skips that tensor and reports a size-mismatch error instead of copying it. However, I want it to copy the parameters anyway. For example, I may change the shape of a tensor from (10, 100) to (10, 110). I would still like to copy over the first 100 columns of the old tensor as starting values.
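A small sketch of the mismatch, using nn.Linear as a stand-in (an assumption; the post doesn't say which layer is resized). nn.Linear(100, 10) stores a weight of shape (10, 100) and nn.Linear(110, 10) one of shape (10, 110). With nn.Module.load_state_dict, the mismatched tensor is not copied and a RuntimeError reporting the size mismatch is raised once loading finishes, even with strict=False:

```python
import torch.nn as nn

old = nn.Linear(100, 10)  # weight shape (10, 100)
new = nn.Linear(110, 10)  # weight shape (10, 110)

try:
    new.load_state_dict(old.state_dict(), strict=False)
    error = None
except RuntimeError as err:
    error = str(err)

# The error message names the offending key.
print(error is not None and "size mismatch for weight" in error)  # True
```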

I think the simplest approach is to pre-process your state_dict before passing it to Module.load_state_dict. That way your code keeps working with newer versions of PyTorch, since the internal API can change without notice.
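A minimal sketch of that pre-processing, assuming the old and new tensors have the same number of dimensions and that copying the overlapping leading block (e.g. the first 100 columns) is what you want. copy_overlap is a hypothetical helper, not a PyTorch API. Keys missing from the checkpoint, or tensors whose rank differs, keep the new model's initial values:

```python
import torch
import torch.nn as nn


def copy_overlap(old_state, model):
    """Hypothetical helper: build a state_dict for `model`, copying as much of
    `old_state` as fits into each of the model's tensors."""
    new_state = {}
    for key, target in model.state_dict().items():
        source = old_state.get(key)
        if source is None or source.dim() != target.dim():
            # Missing key or different rank: keep the model's current values.
            new_state[key] = target
        elif source.shape == target.shape:
            new_state[key] = source
        else:
            # Copy the overlapping leading block; the rest keeps its initial values.
            merged = target.clone()
            region = tuple(slice(0, min(s, t)) for s, t in zip(source.shape, target.shape))
            merged[region] = source[region]
            new_state[key] = merged
    return new_state


old = nn.Linear(100, 10)  # weight shape (10, 100)
new = nn.Linear(110, 10)  # weight shape (10, 110)
new.load_state_dict(copy_overlap(old.state_dict(), new))

print(torch.equal(new.weight[:, :100], old.weight))  # True
```

Because the returned dict contains exactly the model's own keys with matching shapes, load_state_dict can stay in its default strict mode, and no private method has to be overridden.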
