How can I remove keys from a model's state_dict?
import torch

model_checkpoint_path = 'xyz.ckpt'
model_checkpoint_load = torch.load(model_checkpoint_path, map_location='cpu')
model_state_dict = model_checkpoint_load['state_dict']
model_state_dict = model_state_dict.copy()
for key in model_state_dict.keys():
    if key.startswith('loss'):
        model_state_dict.pop(key)
        # print(key)  # ==> This would not throw an error
Error:
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[89], line 1
----> 1 for key in model_state_dict.keys():
      3     if key.startswith('loss'):
      4         model_state_dict.pop(key)

RuntimeError: OrderedDict mutated during iteration
Thank you.
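For reference, the error arises because the loop pops entries from the same OrderedDict whose keys() view it is iterating over. Below is a minimal sketch of two common workarounds. It uses a small stand-in OrderedDict with hypothetical keys and values instead of the real checkpoint, so the names `pruned` and `filtered` and the sample entries are assumptions for illustration only.

```python
from collections import OrderedDict

# Stand-in for model_checkpoint_load['state_dict'];
# these keys and values are hypothetical, not taken from a real checkpoint.
model_state_dict = OrderedDict([
    ('encoder.weight', 1.0),
    ('loss.weight', 2.0),
    ('loss_fn.pos_weight', 3.0),
    ('decoder.bias', 4.0),
])

# Option 1: iterate over a snapshot of the keys (list(...)),
# so popping from the dict does not disturb the iteration.
pruned = model_state_dict.copy()
for key in list(pruned.keys()):
    if key.startswith('loss'):
        pruned.pop(key)

# Option 2: build a new OrderedDict that keeps only the wanted entries.
filtered = OrderedDict(
    (key, value)
    for key, value in model_state_dict.items()
    if not key.startswith('loss')
)

print(list(pruned.keys()))  # ['encoder.weight', 'decoder.bias']
```

Both options leave the original dict untouched. If the model itself still defines the removed parameters, loading the pruned dict would likely need `model.load_state_dict(pruned, strict=False)`; whether that applies depends on how the model is defined.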