Post load_state_dict hook

I have a use case for modifying a model’s parameters after their values are loaded, but I don’t think PyTorch has a hook for load_state_dict. Currently I just have a caching method that runs the transformation once on the first call to forward, but this makes the model JIT-unfriendly. Is there a better way to handle this?

PyTorch uses the internal _register_state_dict_hook and _register_load_state_dict_pre_hook methods. However, since these methods are internal, their interface could change without a deprecation warning, so I wouldn’t depend on them.

Based on your description it seems you are looking for a “post_state_dict_hook”?
If so, wouldn’t applying the desired manipulations to the model parameters after calling model.load_state_dict work? I’m unsure why a hook would be needed for it.
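
As a rough sketch of what I mean, you could wrap the post-processing in an overridden load_state_dict so it always runs right after loading. The module NormalizedLinear, the helper _post_load, and the row-normalization it applies are all made up for illustration; substitute your own transformation.

```python
import torch
import torch.nn as nn


class NormalizedLinear(nn.Module):
    """Hypothetical module whose weight rows are rescaled to unit norm after loading."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def _post_load(self):
        # Hypothetical transformation: normalize each weight row in place.
        with torch.no_grad():
            w = self.linear.weight
            w.div_(w.norm(dim=1, keepdim=True).clamp_min(1e-12))

    def load_state_dict(self, state_dict, *args, **kwargs):
        # Load the values as usual, then apply the transformation once.
        result = super().load_state_dict(state_dict, *args, **kwargs)
        self._post_load()
        return result

    def forward(self, x):
        # No caching or first-call check is needed here anymore.
        return self.linear(x)


src = NormalizedLinear(4, 3)
dst = NormalizedLinear(4, 3)
dst.load_state_dict(src.state_dict())
```

One caveat: the override only runs when load_state_dict is called on this module directly. If the module is a child of a larger model and you call load_state_dict on the parent, the child’s override is not invoked, so in that case you would call the transformation explicitly after loading the parent.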