Proper way to load a pruned network

Yep, as @Jayant_Parashar said: remove the pruning reparametrization prior to saving the state_dict.
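
For reference, here is a minimal sketch of that approach (the toy architecture, pruning amount, and file name are just placeholders):

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# hypothetical toy model, purely for illustration
model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 1))

# prune 30% of the first layer's weights; this creates the weight_orig parameter,
# the weight_mask buffer, and a forward pre-hook that recomputes 'weight'
prune.l1_unstructured(model[0], name='weight', amount=0.3)

# make the pruning permanent: folds weight_orig * weight_mask back into a plain
# 'weight' parameter and removes the mask and the hook
prune.remove(model[0], 'weight')

# the state_dict now holds only ordinary 'weight' / 'bias' tensors, so it loads
# into a fresh, unpruned instance of the same architecture without any surprises
torch.save(model.state_dict(), 'state_dict_after_remove.pth')
```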

Yet another solution is to save out the whole model instead of the state dict while it’s still pruned:
torch.save(pruned_model, 'pruned_model.pth'), and then restore it with pruned_model = torch.load('pruned_model.pth'). This is a bit more fragile, though, because the whole model object gets pickled, so the model class definition has to be importable from the same location when you load it back.
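
For completeness, a quick sketch of this variant (the class and file name here are hypothetical); note that the class definition must be importable at load time:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

class MyNet(nn.Module):  # hypothetical model class
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

pruned_model = MyNet()
prune.l1_unstructured(pruned_model.fc, name='weight', amount=0.3)

# pickle the entire model object, reparametrization (weight_orig / weight_mask) included
torch.save(pruned_model, 'pruned_model.pth')

# restore later; this only works if MyNet is importable from the same module path
# (on recent PyTorch versions you may also need torch.load(..., weights_only=False))
pruned_model = torch.load('pruned_model.pth')
```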

If, however, you care about retaining the masks, or you have inherited a state_dict from somewhere else that contains the pruned reparametrization (i.e. the various weight_mask buffers and weight_orig parameters), then the solution is to: 1) put your newly instantiated model into a pruned state using prune.identity, which creates all the objects you’d expect, but with masks of ones; 2) load the state_dict, which should now match the model.
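
Roughly, and assuming the state_dict came from the same (hypothetical) architecture with the first layer's weight pruned, that looks like:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# fresh instance of the same architecture the pruned state_dict was saved from
model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 1))

# 1) apply prune.identity to every parameter that was pruned in the original model;
#    this creates weight_orig, a weight_mask of all ones, and the forward pre-hook
prune.identity(model[0], name='weight')

# 2) the keys now line up with the pruned reparametrization, so loading just works
state_dict = torch.load('pruned_reparam_state_dict.pth')  # hypothetical file name
model.load_state_dict(state_dict)
```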
