Proper way to load a pruned network

Hello,

I wanted to raise the question of what the most general and elegant way is to load a model pruned with the torch.nn.utils.prune functionality. There is an easy fix, which I'll post as well, but I don't think it is satisfactory.

Problem
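Given a model (I am using MS-D here: https://github.com/ahendriksen/msd_pytorch), pruning replaces each pruned parameter, e.g. weight, with a weight_orig parameter and a weight_mask buffer. It also registers a forward pre-hook that recomputes weight = weight_orig * weight_mask before every forward pass, and gradients reach weight_orig through autograd. If I simply instantiate the model and try to load the state_dict of a pruned network, I get an error about missing and unexpected keys: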

import torch

class MSDModel:
    def __init__(self, *args, **kwargs):
        ...  # builds self.net and self.optimizer (details elided)

    def save(self, path, epoch):
        state = {
            "epoch": int(epoch),
            "state_dict": self.net.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }
        torch.save(state, path)

    def simple_load(self, path, strict=True):
        state = torch.load(path)
        self.net.load_state_dict(state["state_dict"], strict=strict)
        self.optimizer.load_state_dict(state["optimizer"])
        self.net.cuda()
        epoch = state["epoch"]
        return epoch

model = MSDSegmentationModel(...)
model.simple_load(pruned_nets_path + netname)

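This raises a RuntimeError listing missing and unexpected keys, and rightly so: the model doesn't know about weight_orig etc. Calling model.simple_load(..., strict=False) loads the network without errors, but only because the mismatched keys are skipped. The masks are ignored and, since the checkpoint stores the pruned weights under weight_orig rather than weight, the pruned layers keep the randomly initialized weights of the new model.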
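To make the failure concrete, here is a minimal sketch in which a single hypothetical nn.Conv2d stands in for the MS-D network. It shows which keys pruning puts in the state_dict, that a strict load raises, and what strict=False silently skips.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Hypothetical stand-in for one pruned layer of the real network
pruned = nn.Conv2d(1, 4, kernel_size=3)
prune.l1_unstructured(pruned, name="weight", amount=0.5)
pruned_sd = pruned.state_dict()
print(sorted(pruned_sd.keys()))  # ['bias', 'weight_mask', 'weight_orig']

# A freshly built, unpruned layer expects plain 'weight' and 'bias'
fresh = nn.Conv2d(1, 4, kernel_size=3)
try:
    fresh.load_state_dict(pruned_sd)  # strict=True by default
    strict_ok = True
except RuntimeError as err:
    strict_ok = False
    print("strict load failed:", err)

# strict=False does not raise, but mismatched keys are simply skipped,
# so fresh.weight keeps its random initialization
result = fresh.load_state_dict(pruned_sd, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", sorted(result.unexpected_keys))
```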

Easy solution
I wanted to include a workaround for anyone who just wants a quick fix: before loading, call the same predefined pruning method on the modules that were pruned in the network you are trying to load, with a pruning amount of 0%. This registers the correct hooks and creates the missing parameters and buffers, after which loading works fine.
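A minimal sketch of this workaround, assuming the checkpoint was produced with prune.l1_unstructured on the weight of every Conv2d. The network and the helpers make_net and prune_convs are hypothetical stand-ins.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def make_net():
    # Hypothetical stand-in for the real network
    return nn.Sequential(
        nn.Conv2d(1, 4, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(4, 1, kernel_size=1),
    )

def prune_convs(net, amount):
    # Apply L1 unstructured pruning to the weight of every Conv2d in net
    for module in net.modules():
        if isinstance(module, nn.Conv2d):
            prune.l1_unstructured(module, name="weight", amount=amount)
    return net

# Checkpoint of a network pruned at 50%
checkpoint = prune_convs(make_net(), amount=0.5).state_dict()

# Workaround: "prune" the fresh network by 0%. The masks are all ones, but
# weight_orig, weight_mask and the forward pre-hooks now exist
net = prune_convs(make_net(), amount=0.0)
net.load_state_dict(checkpoint)  # strict load now succeeds and restores the masks
```

This only works if exactly the same modules and parameter names are pruned as in the checkpoint; any mismatch brings back the missing/unexpected key error.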

General solution? (does not work yet)
In my opinion, the above is not really a proper solution. Instead, I would like to have a loading function which introduces the missing parameters in the dicts and sets the hooks properly. I have made a start below (including some code from: https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/utils/torchtools.py) but this does not properly set the buffers, hooks etc. and simply does not work. The idea is that the expanded_loading flag would allow the function to add new parameters etc. to the existing model.

from collections import OrderedDict
import warnings

import torch

class MSDModel:
    def __init__(self, *args, **kwargs):
        ...  # builds self.net and self.optimizer (details elided)

    def load(self, path, strict=True, expanded_loading=False):
        state = torch.load(path)
        if 'state_dict' in state:
            state_dict = state['state_dict']
        else:
            state_dict = state

        if strict:
            self.net.load_state_dict(state_dict)
        else:
            model_dict = self.net.state_dict()
            new_state_dict = OrderedDict()
            matched_keys, discarded_keys = [], []

            if not expanded_loading:
                for k, v in state_dict.items():
                    if k in model_dict and model_dict[k].size() == v.size():
                        new_state_dict[k] = v
                        matched_keys.append(k)
                    else:
                        discarded_keys.append(k)
            else:
                for k, v in state_dict.items():
                    new_state_dict[k] = v
                    matched_keys.append(k)

            model_dict.update(new_state_dict)
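            # Load the merged dict built above. With expanded_loading=True this
            # still fails: the net has no weight_orig/weight_mask to receive the extra keys.
            self.net.load_state_dict(model_dict)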

            if len(matched_keys) == 0:
                warnings.warn(
                    'The pretrained weights cannot be loaded, '
                    'please check the key names manually '
                    '(** ignored and continue **)'
                )
            else:
                print(
                    'Successfully loaded pretrained weights.'
                )
                if len(discarded_keys) > 0:
                    print(
                        '** The following keys are discarded '
                        'due to unmatched keys or layer size: {}'.
                        format(discarded_keys)
                    )
        self.optimizer.load_state_dict(state["optimizer"])
        self.net.cuda()

        epoch = state["epoch"]
        return epoch

I am out of my depth here, so I was hoping somebody could help me write a proper loading function for pruned networks.

Kind regards,

Richard

Remove the pruning before saving with prune.remove(module, "weight"), where module is the pruned layer. This makes the pruning permanent.
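A minimal sketch of this approach, with a hypothetical nn.Linear standing in for a pruned layer and an in-memory io.BytesIO buffer standing in for the checkpoint file:

```python
import io

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

layer = nn.Linear(8, 4)
prune.l1_unstructured(layer, name="weight", amount=0.5)

# Fold the mask into the weight: weight becomes a regular Parameter again,
# and weight_orig, weight_mask and the forward pre-hook are removed
prune.remove(layer, "weight")

buffer = io.BytesIO()  # stands in for a file path
torch.save(layer.state_dict(), buffer)
buffer.seek(0)

# A plain, never-pruned layer can now load the checkpoint directly
restored = nn.Linear(8, 4)
restored.load_state_dict(torch.load(buffer))
```

The pruned entries survive only as zeros in weight. No mask is stored any more, so further training can make those entries nonzero again.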

Yep, as @Jayant_Parashar said: remove the pruning reparametrization prior to saving the state_dict.

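Yet another solution is to save the whole model instead of the state dict while it's still pruned:
torch.save(pruned_model, 'pruned_model.pth'), and then restore it with pruned_model = torch.load('pruned_model.pth'). This is a bit riskier because the whole object is pickled, so the model's class must be importable under the same module path at load time. On recent PyTorch versions, where torch.load defaults to weights_only=True, loading a whole model also requires passing weights_only=False.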
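A minimal sketch of saving the whole pruned model, assuming PyTorch 1.13 or later (where torch.load accepts weights_only). The nn.Sequential model is a hypothetical stand-in, and io.BytesIO replaces 'pruned_model.pth'.

```python
import io

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
prune.l1_unstructured(model[0], name="weight", amount=0.5)

buffer = io.BytesIO()  # stands in for 'pruned_model.pth'
torch.save(model, buffer)  # pickles the module, including its pruning hooks
buffer.seek(0)

# weights_only=False is needed to unpickle a full nn.Module; only do this
# for files you trust, since unpickling can execute arbitrary code
restored = torch.load(buffer, weights_only=False)
```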

If, however, you care about retaining the masks, or you have inherited a state_dict from somewhere else which contains the pruned reparametrization (so the various weight_mask and weight_orig buffers and parameters), then the solution is to: 1) put your newly instantiated model in a pruned state using prune.identity, which creates all the objects you’d expect, but with masks of ones; 2) load the state_dict, which should now fit the model.
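The two steps above can be sketched as follows. The small nn.Sequential model and make_net are hypothetical, and the source checkpoint is pruned with prune.random_unstructured purely for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def make_net():
    # Hypothetical stand-in for the real model
    return nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))

# A state_dict that still carries the pruning reparametrization for layer 0
source = make_net()
prune.random_unstructured(source[0], name="weight", amount=0.3)
state_dict = source.state_dict()

# 1) Put the new model in a pruned state: identity creates weight_orig,
#    an all-ones weight_mask and the forward pre-hook
model = make_net()
prune.identity(model[0], "weight")

# 2) The keys now match, so a strict load works and restores the real mask
model.load_state_dict(state_dict)
```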

Let me also add that, in the last scenario, loading the state_dict into a newly instantiated model will properly fill in all your weight_orig and weight_mask tensors with the info from the state_dict, BUT weight itself will still be the randomly initialized tensor from the new model instantiation.

Why does this matter? weight_orig and weight_mask together can be used to recompute weight on the fly, through the forward_pre_hook that PyTorch pruning uses. But this hook only acts during a forward call. Without that call, weight will just be some random tensor that has nothing to do with the loaded weight_orig and weight_mask.

Therefore, either serialize your models after removing the pruning reparametrization, or remember to set weight correctly (by hand or with a forward call) before using the model or trying to prune it again. In practice, this means preferably calling _ = model(X), where X is some (even fake) input data batch, or, alternatively, setting the weight by hand with module.weight = module.weight_orig * module.weight_mask (be careful with this).
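A minimal sketch of the stale-weight problem and the forward-call fix, using a hypothetical nn.Linear layer. The by-hand alternative is shown as a comment.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

source = nn.Linear(8, 4)
prune.l1_unstructured(source, name="weight", amount=0.5)

model = nn.Linear(8, 4)
prune.identity(model, "weight")
model.load_state_dict(source.state_dict())  # updates weight_orig/weight_mask in place

# weight was computed by prune.identity from the new model's random init,
# so it does not yet reflect the loaded weight_orig and weight_mask
print(torch.equal(model.weight, model.weight_orig * model.weight_mask))  # False

# A forward call with (fake) input lets the pre-hook recompute weight
_ = model(torch.zeros(1, 8))
# By-hand alternative: model.weight = model.weight_orig * model.weight_mask

print(torch.equal(model.weight, model.weight_orig * model.weight_mask))  # True
```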

Thank you, this post was helpful. I wanted to add a chunk of code that is currently working for me and might help others. When you run prune.identity on your model to load a checkpoint, there are two catches: first, you have to keep your pruned and unpruned checkpoints separate; second, the procedure needs to know which parts of your model were pruned. I was able to solve both with the code below, which may fail in weird ways, but hasn't yet.

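    import torch.nn.utils.prune as prune

    # Find every "<param>_mask" entry in the checkpoint and re-create the matching
    # pruning reparametrization on the model with prune.identity.
    # Assumes keys of the form "<attr>.<index>.<param>_mask", e.g. "layers.3.weight_mask".
    for part in checkpoint.values():
        if hasattr(part, "keys"):  # only dict-like entries, e.g. state dicts
            for key in part.keys():
                if "_mask" in key:
                    pieces = key.split('.')
                    model_section = getattr(model, pieces[0])[int(pieces[1])]
                    prune.identity(model_section, pieces[2].replace("_mask", ""))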

It goes through the checkpoint to find the _mask keys, then looks up the corresponding submodule of the model and runs prune.identity on it.
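The snippet above assumes every key has exactly the form attr.index.param_mask. A possible generalization, sketched under the assumption of PyTorch 1.9 or later (for nn.Module.get_submodule), splits off only the last key component so modules can be nested at any depth. reapply_pruning and make_net are hypothetical names. The helper treats any key ending in _mask as a pruning mask, so an unrelated buffer with that suffix would be misread. As discussed above, weight itself is only refreshed on the next forward call.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def reapply_pruning(model, state_dict):
    """Hypothetical helper: prune.identity for every '<path>.<param>_mask' key."""
    for key in state_dict:
        if not key.endswith("_mask"):
            continue
        module_path, _, mask_name = key.rpartition(".")
        param_name = mask_name[: -len("_mask")]
        module = model.get_submodule(module_path)  # "" returns model itself
        if not hasattr(module, param_name + "_orig"):  # skip if already pruned
            prune.identity(module, param_name)
    return model

def make_net():
    # Hypothetical nested model, giving keys like "body.0.weight_mask"
    return nn.ModuleDict({
        "body": nn.Sequential(nn.Conv2d(1, 4, 3), nn.ReLU()),
        "head": nn.Linear(4, 2),
    })

source = make_net()
prune.l1_unstructured(source["body"][0], "weight", amount=0.5)
prune.l1_unstructured(source["head"], "bias", amount=0.5)
checkpoint = {"epoch": 3, "state_dict": source.state_dict()}

model = reapply_pruning(make_net(), checkpoint["state_dict"])
model.load_state_dict(checkpoint["state_dict"])  # strict load succeeds
```

Passing only the model's state_dict, rather than iterating over every entry of the checkpoint, keeps optimizer state and other metadata out of the search.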