Transferring pruning hooks

Hello!
I am currently having difficulties implementing the lottery ticket hypothesis.
I have a trained and pruned model, and I am trying to transfer the weight masks from the trained network to an untrained one.
I have tried this but it has no effect:

def transfer_hooks(trained_model,fresh):
    fresh.conv1.register_forward_pre_hook(trained_model.conv1._forward_pre_hooks)

Also, deepcopy seems to be broken when copying pruned layers, throwing

RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment

I have also tried with

copymodel.load_state_dict(torch.load(model_path))

but it only works if I apply pruning once on the copymodel before loading; otherwise it throws:

RuntimeError: Error(s) in loading state_dict for network:
	Missing key(s) in state_dict: "conv1.weight", "conv2.weight", "fc1.weight", "fc2.weight". 
	Unexpected key(s) in state_dict: "conv1.weight_orig", "conv1.weight_mask", "conv2.weight_orig", "conv2.weight_mask", "fc1.weight_orig", "fc1.weight_mask", "fc2.weight_orig", "fc2.weight_mask". 

Is there a more elegant solution out there that doesn't require applying a pruning step to the network beforehand and simply lets me transfer the “pruning” component?

Your pruned model is in a pruned reparametrization state, in which each pruned parameter has been replaced as follows: weight --> weight_orig, weight_mask.
Now, if you try to load those into an unpruned model via a state_dict, it will of course fail, because that reparametrization doesn’t exist in your unpruned model, and the keys in the state_dicts don’t match.
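To make the key mismatch concrete, here is a minimal sketch on a single layer. The layer shape and the choice of `l1_unstructured` are just assumptions for illustration, not taken from the original network:

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

# A small stand-in layer (the shape is arbitrary, chosen for illustration)
layer = nn.Conv2d(1, 4, kernel_size=3)
keys_before = sorted(layer.state_dict().keys())

# Pruning replaces the `weight` parameter with `weight_orig` (a parameter)
# and `weight_mask` (a buffer); `weight` becomes a plain attribute that is
# recomputed from those two before every forward pass.
prune.l1_unstructured(layer, name="weight", amount=0.5)
keys_after = sorted(layer.state_dict().keys())

print(keys_before)  # ['bias', 'weight']
print(keys_after)   # ['bias', 'weight_mask', 'weight_orig']
```

This is the same set of missing and unexpected keys reported in the error message above.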

Instructions on how to solve this are provided in these answers: Proper way to load a pruned network
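For the lottery-ticket use case specifically, one possible approach (a sketch, not necessarily what the linked answers describe) is to re-apply each mask to the fresh model with `prune.custom_from_mask`. The helper `transfer_masks` and the toy network from `make_net` are hypothetical names introduced here; the sketch assumes both models share the same architecture and module names:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def make_net():
    # Hypothetical stand-in for the poster's network
    return nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))

def transfer_masks(trained, fresh):
    """Apply every pruning mask found in `trained` to the same-named module of `fresh`."""
    fresh_modules = dict(fresh.named_modules())
    for mod_name, module in trained.named_modules():
        for buf_name, buf in module.named_buffers(recurse=False):
            if buf_name.endswith("_mask"):
                param_name = buf_name[: -len("_mask")]  # e.g. "weight"
                prune.custom_from_mask(
                    fresh_modules[mod_name], name=param_name, mask=buf.clone()
                )

# Prune a "trained" model (training itself is omitted in this sketch)
trained = make_net()
for m in trained:
    if isinstance(m, nn.Linear):
        prune.l1_unstructured(m, name="weight", amount=0.5)

# The fresh model keeps its own initial weights and only receives the masks
fresh = make_net()
transfer_masks(trained, fresh)
```

After this call the fresh model is in the same reparametrized state (`weight_orig`, `weight_mask`), so `load_state_dict` from a pruned checkpoint should also find matching keys if you need it.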

Let me know if you have any further questions.

Hey @Dan_Blanaru, any luck with transferring the forward pre-hook from a checkpoint to a newly instantiated model? If I understand correctly, PyTorch hooks are attached to the model instance and not to the parameters. So unless one saves the model instance along with the state_dict, the hooks are lost?
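
A quick way to check where the hook and the mask each live is sketched below. The layer size and pruning method are arbitrary choices, and `_forward_pre_hooks` is a private attribute (the same one used in the original post), so treat this as an inspection aid rather than a supported API:

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 2)
prune.random_unstructured(layer, name="weight", amount=0.5)

# The pruning hook is stored on the module instance ...
hooks = list(layer._forward_pre_hooks.values())
print(len(hooks), isinstance(hooks[0], prune.BasePruningMethod))

# ... while the state_dict only carries tensors: the original weights and the mask.
state_keys = list(layer.state_dict().keys())
print(state_keys)
```

So the hook object itself is not part of the checkpoint, but the mask is, and the hook can be recreated on a new instance by applying a pruning method to it again.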