Delete state_dict keys

How can I remove model state_dict keys?

import torch

model_checkpoint_path = 'xyz.ckpt'
model_checkpoint_load = torch.load(model_checkpoint_path, map_location='cpu')
model_state_dict = model_checkpoint_load['state_dict']
model_state_dict = model_state_dict.copy()  # note: this copy is both iterated and mutated below

for key in model_state_dict.keys():
    if key.startswith('loss'):
        model_state_dict.pop(key)
        # print(key)  # ==> printing instead of popping does not raise an error

Error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[89], line 1
----> 1 for key in model_state_dict.keys():
      3     if key.startswith('loss'):
      4         model_state_dict.pop(key)

RuntimeError: OrderedDict mutated during iteration

Thanking you.
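The error comes from removing entries from the same OrderedDict that the loop is iterating over; the `.copy()` does not help because the copy is both iterated and mutated. A minimal sketch, assuming a plain `OrderedDict` with hypothetical key names and integer values as a stand-in for the checkpoint's state_dict, reproduces the error and shows the usual fix of iterating over a snapshot `list(...)` of the keys:

```python
from collections import OrderedDict


def make_state_dict():
    # Toy stand-in for a state_dict: key names are hypothetical, values are plain ints
    return OrderedDict(
        [("encoder.weight", 1), ("loss.weight", 2), ("loss.bias", 3), ("decoder.weight", 4)]
    )


# 1) Reproduce the problem: pop() from the OrderedDict that is being iterated
sd = make_state_dict()
try:
    for key in sd.keys():
        if key.startswith("loss"):
            sd.pop(key)
    error = None
except RuntimeError as exc:
    error = exc
print("error:", error)

# 2) Fix: iterate over a snapshot of the keys, so the dict can be mutated freely
sd = make_state_dict()
for key in list(sd.keys()):
    if key.startswith("loss"):
        sd.pop(key)
print(list(sd))  # ['encoder.weight', 'decoder.weight']
```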

You could collect the keys you want to remove in the loop and delete them afterwards:

from torchvision import models

model = models.resnet18()

sd = model.state_dict()

layers_to_remove = []
for key in sd:
    if "conv" in key:
        layers_to_remove.append(key)
print(layers_to_remove)
# ['conv1.weight', 'layer1.0.conv1.weight', 'layer1.0.conv2.weight', 'layer1.1.conv1.weight', 'layer1.1.conv2.weight', 'layer2.0.conv1.weight', 'layer2.0.conv2.weight', 'layer2.1.conv1.weight', 'layer2.1.conv2.weight', 'layer3.0.conv1.weight', 'layer3.0.conv2.weight', 'layer3.1.conv1.weight', 'layer3.1.conv2.weight', 'layer4.0.conv1.weight', 'layer4.0.conv2.weight', 'layer4.1.conv1.weight', 'layer4.1.conv2.weight']

print(len(sd))
# 122
print(len(layers_to_remove))
# 17

for key in layers_to_remove:
    del sd[key]
print(len(sd))
# 105
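If you would rather leave the original state_dict untouched, a sketch of an alternative is to build a filtered copy. The helper name `drop_keys` and the toy keys below are hypothetical. It also carries over the `_metadata` attribute that `Module.state_dict()` attaches and `load_state_dict()` reads for per-module versioning, since a plain dict comprehension would silently drop it:

```python
from collections import OrderedDict


def drop_keys(state_dict, should_drop):
    """Return a filtered copy of `state_dict`; the original is not modified.

    `should_drop` is any callable that takes a key and returns True to remove it.
    """
    filtered = OrderedDict((k, v) for k, v in state_dict.items() if not should_drop(k))
    # Keep the versioning metadata that load_state_dict() looks for, if present
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        filtered._metadata = metadata
    return filtered


# Toy stand-in for model.state_dict(); values are plain ints here
sd = OrderedDict([("conv1.weight", 0), ("bn1.weight", 1), ("layer1.0.conv1.weight", 2), ("fc.bias", 3)])
sd._metadata = OrderedDict([("", {"version": 1})])

new_sd = drop_keys(sd, lambda k: "conv" in k)
print(list(new_sd))  # ['bn1.weight', 'fc.bias']
print(len(sd))       # 4 (original unchanged)
```

With a real model you would pass `model.state_dict()`, and loading the result back into the same architecture needs `model.load_state_dict(new_sd, strict=False)`, because the removed entries are then reported as missing keys.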

Thank you @ptrblck :slight_smile:

Hi @ptrblck, I know this post has been marked as solved, but I am having a similar issue. I have a model that I'm quantizing, and before quantization it goes through some calibration. The problem is that when I try to load the model from a checkpoint to resume training after a server timeout, I get a missing keys error: the calibration features are not saved. I removed those keys from the state dict and saved the result in sd as you suggested, but how do I load the new sd into the model so that the saved checkpoint loads correctly?

Could you post a code snippet showing how you are quantizing, storing, and reloading the model?

I quantize using IBM’s foundation model stack: github.com/foundation-model-stack/fms-model-optimizer. The call to quantize is:

self.model, optimizer = qmodel_prep(self.model, data_mb, self.qcfg,
                                    optimizer=optimizer,
                                    save_fname=f"{base_name}_graph.pt",
                                    qlast=self.args.qlast)

Saving a model:

def save(self, model, path):
    state = {}
    self.save_to_state(model, state)

    path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(state, path)

def save_to_state(self, model, state):
    state["model_hash"] = self.model_hash
    state["model_state"] = model.state_dict()

Loading saved checkpoint:

def _default_dict_loader(model, model_state):
    model.load_state_dict(model_state)
    # With strict=False the checkpoint loads, but testing returns all NaNs

def load(self, model, path, device, checkHash=True):
    state = torch.load(path, map_location=device)
    self.load_from_state(model, state, checkHash)

def load_from_state(self, model, state, checkHash=True):
    if checkHash and self.model_hash != state["model_hash"]:
        msg = "ERROR: saved model does not match current model params"
        print(msg)  # log if we are logging
        print(msg, file=sys.stderr)
        sys.exit(1)

    self.dict_loader(model, state["model_state"])

I used your example to modify the state_dict after the qmodel_prep call, then attempted to load the new state_dict into the model using:

self.model.load_state_dict(state_dict, strict=False)
# if strict=True, this fails.

Unfortunately, I’m not familiar with this repository and also don’t know how the model_hash is created, which seems to fail in your code?
Do you get any proper error message while trying to load the state_dict?

Below is a snippet of the error messages; many more keys are missing than shown here. The saved model does contain quantized features, but those specific calibration features were not saved.

[Screenshots of the missing-keys error messages, 2025-04-17]
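When the list of missing keys is long, it can help to compare the key sets directly and group the missing ones by name, to check whether only calibration entries are absent or real weights as well. A small sketch, assuming hypothetical key names (with real objects you would pass `model.state_dict().keys()` and `state["model_state"].keys()`); `summarize_key_mismatch` is an illustrative helper, not part of PyTorch or fms-model-optimizer:

```python
from collections import Counter


def summarize_key_mismatch(model_keys, checkpoint_keys):
    """Compare two key collections and group the missing keys by their last component."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)     # expected by the model, absent in the checkpoint
    unexpected = sorted(checkpoint_keys - model_keys)  # in the checkpoint, unknown to the model
    # Count missing keys by their final dotted component to see which kinds
    # of entries are absent (e.g. only calibration buffers vs. real weights)
    missing_by_name = Counter(key.rsplit(".", 1)[-1] for key in missing)
    return missing, unexpected, missing_by_name


# Hypothetical key names for illustration only
model_keys = ["fc1.weight", "fc1.bias", "fc1.calib_min", "fc1.calib_max", "fc2.weight", "fc2.calib_min"]
checkpoint_keys = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.extra"]

missing, unexpected, missing_by_name = summarize_key_mismatch(model_keys, checkpoint_keys)
print(missing)          # ['fc1.calib_max', 'fc1.calib_min', 'fc2.calib_min']
print(unexpected)       # ['fc2.extra']
print(missing_by_name)  # Counter({'calib_min': 2, 'calib_max': 1})
```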

I got the checkpoint to load correctly. I disabled both calibration functions and then set strict=False when loading the model from the checkpoint. The model hash was created correctly; the issue was that the calibration features were not saved after the first round of training.

Two sets of calibration are performed automatically if their arguments are > 0, but they aren't necessary when resuming QAT from a checkpoint, since calibration was already done in the original run. The model still has to be prepared for quantization (the qmodel_prep call) when resuming from a checkpoint. Some calibration features are still loaded even though calibration is turned off, but they can be ignored during loading, since they are discarded during training.
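One risk of `strict=False` is that it also skips genuinely missing weights, which then keep their freshly initialized values. A sketch of a guarded load, assuming the ignorable calibration keys can be recognized by a substring: the `"calib"` pattern, the helper name, and the `Toy` model are all hypothetical, since the actual key names depend on fms-model-optimizer. It relies only on `load_state_dict` returning a named tuple with `missing_keys` and `unexpected_keys`:

```python
import torch
from torch import nn


def load_with_allowed_mismatches(model, state_dict, allowed_substrings):
    """Load with strict=False, then raise on any mismatch not covered by the allow-list.

    Note: by the time the check runs, strict=False has already copied every
    matching entry into `model`.
    """
    result = model.load_state_dict(state_dict, strict=False)
    leftover = [
        key
        for key in result.missing_keys + result.unexpected_keys
        if not any(s in key for s in allowed_substrings)
    ]
    if leftover:
        raise RuntimeError(f"state_dict mismatch not covered by allow-list: {leftover}")
    return result


class Toy(nn.Module):
    def __init__(self, with_calibration_buffer):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        if with_calibration_buffer:
            # Hypothetical stand-in for a calibration buffer that exists in the
            # prepared model but is absent from the saved checkpoint
            self.register_buffer("calib_scale", torch.ones(1))

    def forward(self, x):
        return self.fc(x)


saved = Toy(with_calibration_buffer=False).state_dict()
model = Toy(with_calibration_buffer=True)
result = load_with_allowed_mismatches(model, saved, allowed_substrings=["calib"])
print(result.missing_keys)  # ['calib_scale']
```

If a real weight such as `fc.bias` were missing from the checkpoint, the helper would raise instead of silently leaving that weight at its initial value.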