Invoking Time of nn.Module _register_state_dict_hook()

Below is the source code of _register_state_dict_hook, quoted from the official PyTorch nn.Module:

    def _register_state_dict_hook(self, hook):
        r"""These hooks will be called with arguments: `self`, `state_dict`,
        `prefix`, `local_metadata`, after the `state_dict` of `self` is set.
        Note that only parameters and buffers of `self` or its children are
        guaranteed to exist in `state_dict`. The hooks may modify `state_dict`
        inplace or return a new one.
        """
        handle = hooks.RemovableHandle(self._state_dict_hooks)
        self._state_dict_hooks[handle.id] = hook
        return handle

The docstring says that “These hooks will be called … after the state_dict of self is set”.
However, I’m not sure when exactly the state_dict of an nn.Module is set. Is it set after initialization? Or every time the weights get updated?

Hi,

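This is linked to when the state_dict() function is called: the hooks run after the module has written its own parameters and buffers into the dictionary being built.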
Note that this function is internal and might change without notice in the future.
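
For anyone who wants to check the timing themselves, here is a minimal sketch (not taken from the PyTorch source). It assumes the private `_register_state_dict_hook` still accepts a hook with the four arguments listed in the docstring above. The `record_hook` function and the `calls` list are made-up names for illustration.

```python
import torch
from torch import nn

calls = []  # one entry (the prefix) per hook invocation


def record_hook(module, state_dict, prefix, local_metadata):
    # Hypothetical hook: only records that it fired.
    # Returning None leaves state_dict unchanged.
    calls.append(prefix)


model = nn.Linear(2, 1)
# Private API: may change without notice.
handle = model._register_state_dict_hook(record_hook)
calls_after_init = len(calls)

# One training step: the weights change, but no state_dict is built.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model(torch.ones(1, 2)).sum().backward()
optimizer.step()
calls_after_step = len(calls)

# Building the state_dict is what triggers the hook.
model.state_dict()
calls_after_state_dict = len(calls)

# After removing the handle, the hook no longer fires.
handle.remove()
model.state_dict()
calls_after_remove = len(calls)

print(calls_after_init, calls_after_step, calls_after_state_dict, calls_after_remove)
```

In this sketch the counter only moves when `state_dict()` is called while the hook is registered. Neither construction nor an optimizer step triggers it.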

Aha, I see.
So the same goes for nn.Module _register_state_dict_pre_hook(), right?

Yes, this is the same idea.

Hi @albanD, a quick follow-up on this thread. I want to filter out some parameters before saving model checkpoints, and I thought I could use _register_state_dict_hook to modify the state_dict (or the destination dict). However, since you said this function is internal and may change, I’m no longer sure this is the best way to do it in PyTorch.

So what’s the proper way to change state_dict and filter out some params (don’t save them into model ckpt files)?

Hey!

The full proposal to fix this is in the pytorch/pytorch GitHub issue #75287, “[RFC] Consolidated and unified state_dict and load_state_dict hooks”.
Contributions adding the missing pieces are welcome!
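
In the meantime, one way to avoid the internal hook entirely is to filter the dictionary returned by `state_dict()` before passing it to `torch.save`, and to load with `strict=False`. This is only a sketch of that idea, not something proposed in the RFC. The `Net` model, the `filtered_state_dict` helper, and the `"backbone."` prefix are made-up names for illustration, and an in-memory buffer stands in for the checkpoint file.

```python
import io

import torch
from torch import nn


class Net(nn.Module):
    # Hypothetical model: "backbone" is the part we do not want to save.
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 4)
        self.head = nn.Linear(4, 2)


def filtered_state_dict(model, skip_prefixes):
    # Hypothetical helper: drop every entry whose key starts with one
    # of the given prefixes. The model itself is left untouched.
    skip = tuple(skip_prefixes)
    return {k: v for k, v in model.state_dict().items() if not k.startswith(skip)}


model = Net()
to_save = filtered_state_dict(model, ["backbone."])

buffer = io.BytesIO()  # stands in for a checkpoint file on disk
torch.save(to_save, buffer)
buffer.seek(0)

# strict=False tolerates the keys that were filtered out and reports
# them in missing_keys instead of raising an error.
restored = Net()
result = restored.load_state_dict(torch.load(buffer), strict=False)
print(sorted(to_save), result.missing_keys, result.unexpected_keys)
```

The trade-off is that every place that saves a checkpoint has to call the helper, whereas a hook would apply automatically. In exchange, nothing here depends on internal API.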