Delete state_dict keys

How can I remove keys from a model's state_dict? My attempt below raises an error:

model_checkpoint_path = 'xyz.ckpt'
model_checkpoint_load = torch.load(model_checkpoint_path, map_location='cpu')
model_state_dict = model_checkpoint_load['state_dict']
# .copy() makes a shallow copy, but the loop below still iterates over and
# pops from that *same* copy, so it does not prevent the error.
model_state_dict = model_state_dict.copy()

# NOTE: this loop raises "RuntimeError: OrderedDict mutated during iteration"
# because keys() is a live view of the dict that pop() is shrinking.
# Iterating over a snapshot instead, e.g.
#     for key in list(model_state_dict.keys()):
# avoids the error.
for key in model_state_dict.keys():
    
    if key.startswith('loss'):
        model_state_dict.pop(key)
        # print(key) # ==> This would not throw an error (nothing is mutated)

Error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[89], line 1
----> 1 for key in model_state_dict.keys():
      3     if key.startswith('loss'):
      4         model_state_dict.pop(key)

RuntimeError: OrderedDict mutated during iteration

Thank you.

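You could collect the keys you want to remove in the loop and delete them afterwards. Deleting entries while you are still iterating over the dict itself is what raises the RuntimeError: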

model = models.resnet18()

sd = model.state_dict()

# Pass 1: snapshot the matching keys into a separate list. Nothing is removed
# yet, so iterating over `sd` here is safe.
layers_to_remove = [key for key in sd if "conv" in key]
print(layers_to_remove)
# ['conv1.weight', 'layer1.0.conv1.weight', 'layer1.0.conv2.weight', 'layer1.1.conv1.weight', 'layer1.1.conv2.weight', 'layer2.0.conv1.weight', 'layer2.0.conv2.weight', 'layer2.1.conv1.weight', 'layer2.1.conv2.weight', 'layer3.0.conv1.weight', 'layer3.0.conv2.weight', 'layer3.1.conv1.weight', 'layer3.1.conv2.weight', 'layer4.0.conv1.weight', 'layer4.0.conv2.weight', 'layer4.1.conv1.weight', 'layer4.1.conv2.weight']

print(len(sd))
# 122
print(len(layers_to_remove))
# 17

# Pass 2: delete from `sd` while iterating over the snapshot list, not over
# `sd` itself, so the dict is never mutated mid-iteration.
for key in layers_to_remove:
    sd.pop(key)
print(len(sd))
# 105
2 Likes

Thank you @ptrblck :slight_smile: