Hello,
I wanted to raise the question of what is the most general and elegant way to load a pruned model (pruned using the `torch.nn.utils.prune` functionality). There is an easy fix, which I'll post as well, but I don't think it is satisfactory.
Problem
Given a model (I am using MS-D here: https://github.com/ahendriksen/msd_pytorch), pruning introduces new parameters such as `weight_orig` and buffers such as `weight_mask`, and registers hooks that make sure the mask is properly applied in the forward/backward passes. If I simply define a model and try to load a pruned network, I get a `KeyError`:

```python
class MSDModel:
    def __init__(...):
        ...

    def save(self, path, epoch):
        state = {
            "epoch": int(epoch),
            "state_dict": self.net.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }
        torch.save(state, path)

    def simple_load(self, path, strict=True):
        state = torch.load(path)
        self.net.load_state_dict(state["state_dict"], strict=strict)
        self.optimizer.load_state_dict(state["optimizer"])
        self.net.cuda()
        epoch = state["epoch"]
        return epoch


model = MSDSegmentationModel(...)
model.simple_load(pruned_nets_path + netname)
```
This gives a `KeyError`, and rightly so: the freshly constructed model doesn't know about `weight_orig` etc. Calling `model.simple_load(..., strict=False)` loads the network without errors, but then the new parameters are ignored, i.e. the model is loaded without the masks and such.
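To make the key mismatch concrete, here is a minimal illustration (using a plain `nn.Linear` rather than MS-D) of how pruning renames the entries in the state dict:

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(3, 3)
print(sorted(layer.state_dict().keys()))
# → ['bias', 'weight']

# After pruning, 'weight' is replaced by 'weight_orig' (a parameter)
# and 'weight_mask' (a buffer); a forward pre-hook recomputes
# layer.weight = weight_orig * weight_mask on every forward pass.
prune.random_unstructured(layer, name="weight", amount=0.5)
print(sorted(layer.state_dict().keys()))
# → ['bias', 'weight_mask', 'weight_orig']
```

A checkpoint saved after pruning therefore contains `weight_orig`/`weight_mask` keys that a freshly constructed model does not have, hence the `KeyError` on a strict load.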
Easy solution
I wanted to include a quick workaround for anyone who just wants it to work: call the predefined pruning method, with a pruning amount of 0, on the same modules that were pruned in the network you are trying to load. This registers the correct hooks and introduces the correct parameters, after which loading works fine.
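As a sketch of that workaround (the helper name `prepare_for_pruned_load` and the choice of `Conv2d`/`Linear` are mine; adjust them to whatever was actually pruned in your network):

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

def prepare_for_pruned_load(net):
    # Prune with amount=0: the mask is all ones, so the weights are
    # untouched, but weight_orig/weight_mask and the forward pre-hook
    # now exist, matching the keys in a pruned checkpoint.
    for module in net.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            prune.l1_unstructured(module, name="weight", amount=0.0)
```

After calling `prepare_for_pruned_load(self.net)`, `simple_load(path, strict=True)` should succeed, because the checkpoint's `weight_orig`/`weight_mask` entries now have matching parameters and buffers to load into.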
General solution? (does not work yet)
In my opinion, the above is not really a proper solution. Instead, I would like a loading function that introduces the missing parameters into the dicts and sets the hooks properly. I have made a start below (including some code from https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/utils/torchtools.py), but it does not properly set the buffers, hooks, etc., and simply does not work. The idea is that the `expanded_loading` flag would allow the function to add new parameters etc. to the existing model.

```python
from collections import OrderedDict
import warnings

import torch


class MSDModel:
    def __init__(...):
        ...

    def load(self, path, strict=True, expanded_loading=False):
        state = torch.load(path)
        if "state_dict" in state:
            state_dict = state["state_dict"]
        else:
            state_dict = state

        if strict:
            self.net.load_state_dict(state_dict)
        else:
            model_dict = self.net.state_dict()
            new_state_dict = OrderedDict()
            matched_keys, discarded_keys = [], []
            if not expanded_loading:
                # Keep only checkpoint keys that exist in the model
                # with matching shapes.
                for k, v in state_dict.items():
                    if k in model_dict and model_dict[k].size() == v.size():
                        new_state_dict[k] = v
                        matched_keys.append(k)
                    else:
                        discarded_keys.append(k)
            else:
                # Take over all keys, including weight_orig/weight_mask.
                for k, v in state_dict.items():
                    new_state_dict[k] = v
                    matched_keys.append(k)
            model_dict.update(new_state_dict)
            self.net.load_state_dict(model_dict, strict=strict)  # unnecessary?

            if len(matched_keys) == 0:
                warnings.warn(
                    "The pretrained weights cannot be loaded, "
                    "please check the key names manually "
                    "(** ignored and continue **)"
                )
            else:
                print("Successfully loaded pretrained weights.")
                if len(discarded_keys) > 0:
                    print(
                        "** The following keys are discarded "
                        "due to unmatched keys or layer size: "
                        "{}".format(discarded_keys)
                    )

        self.optimizer.load_state_dict(state["optimizer"])
        self.net.cuda()
        epoch = state["epoch"]
        return epoch
```
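One possible direction, shown here as an untested sketch rather than a finished solution: instead of copying the extra keys into a merged dict by hand, scan the checkpoint for `*_mask` entries and re-create the pruning reparametrization with `prune.custom_from_mask` before loading strictly. The helper name `load_pruned_state_dict` is mine:

```python
import torch
import torch.nn.utils.prune as prune

def load_pruned_state_dict(net, state_dict):
    """For every '<param>_mask' entry in the checkpoint, re-create the
    pruning reparametrization on the matching module, then load strictly."""
    modules = dict(net.named_modules())
    for key, mask in state_dict.items():
        if key.endswith("_mask"):
            # e.g. 'msd.block.conv.weight_mask' -> module 'msd.block.conv',
            # parameter 'weight'
            module_path, _, mask_name = key.rpartition(".")
            param_name = mask_name[: -len("_mask")]
            prune.custom_from_mask(modules[module_path], name=param_name,
                                   mask=mask)
    net.load_state_dict(state_dict)  # strict load now succeeds
    return net
```

This registers `weight_orig`, `weight_mask`, and the forward pre-hook exactly as they were at save time, so a strict `load_state_dict` matches every key. Whether it covers every layer type MS-D prunes, I don't know; treat it as a starting point.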
I am out of my depth here, so I was hoping somebody could help me write a proper loading function for pruned networks.
Kind regards,
Richard