Hi, I’m trying to modify `nn.Module._load_from_state_dict()` (torch.nn.modules.module — PyTorch 1.7.1 documentation). My class inherits from `nn.Module`:
```python
import itertools
import pdb

import torch.nn as nn

class foo(nn.Module):
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # The following lines are copied from the first few lines of the
        # original _load_from_state_dict()
        for hook in self._load_state_dict_pre_hooks.values():
            hook(state_dict, prefix, local_metadata, strict,
                 missing_keys, unexpected_keys, error_msgs)

        persistent_buffers = {k: v for k, v in self._buffers.items()
                              if k not in self._non_persistent_buffers_set}
        local_name_params = itertools.chain(self._parameters.items(),
                                            persistent_buffers.items())
        local_state = {k: v for k, v in local_name_params if v is not None}
        pdb.set_trace()
```
However, it doesn’t seem to be that simple: `self._load_state_dict_pre_hooks`, `persistent_buffers`, and `local_name_params` are all empty when I inspect them at the `pdb.set_trace()` breakpoint.

Can someone explain why this might be, and how I can correctly modify `_load_from_state_dict()`?
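In case it helps, here is a minimal, self-contained version of what I’m doing (the class and layer names here are mine, not from my real code). The override just reports what the module sees and then defers to the parent implementation:

```python
import torch
import torch.nn as nn

class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        # The parameters live on this child module, not on Foo itself
        self.linear = nn.Linear(2, 2)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Inspect what this particular module sees during loading
        print(f"prefix={prefix!r} own params={list(self._parameters)}")
        # Defer to the stock implementation so loading still works
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)

model = Foo()
model.load_state_dict(Foo().state_dict())
```

When I run this, the top-level module reports an empty `_parameters`, which matches what I see in pdb.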