KeyError: xxxxxxxxxx when calling optimizer.state_dict()

Hello there,

I am training a deep convolutional model in an unsupervised way, based on the DeepCluster pipeline. The approach alternates between the following steps:

0- remove the classifier module
1- feed the inputs
2- compute features
3- cluster the inputs with k-means
4- add a new classifier module
5- use the inputs' cluster assignments as pseudo-labels to train the model

Since I add and remove the classifier module on the fly, I use optimizer.add_param_group() to register the parameters of each newly added classifier with the optimizer, and del optimizer.param_groups[1] to delete them when the module is removed.
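A minimal sketch of that loop, assuming Adam as the optimizer, a tiny hypothetical convolutional backbone, hypothetical sizes FEAT_DIM and NUM_CLUSTERS, and random labels standing in for the k-means assignments (no real clustering is done here):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
FEAT_DIM, NUM_CLUSTERS = 8, 4  # hypothetical sizes

# Hypothetical stand-in for the convolutional feature extractor.
backbone = nn.Sequential(
    nn.Conv2d(3, FEAT_DIM, kernel_size=3), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
)
optimizer = torch.optim.Adam(backbone.parameters(), lr=1e-3)  # param group 0
images = torch.randn(16, 3, 32, 32)

for epoch in range(2):
    # Steps 1-3: compute features, then pseudo-labels. Random labels stand in
    # for the k-means assignments (assumption: no real clustering here).
    with torch.no_grad():
        features = backbone(images)
    pseudo_labels = torch.randint(0, NUM_CLUSTERS, (features.size(0),))

    # Step 4: a fresh classifier, registered as param group 1.
    classifier = nn.Linear(FEAT_DIM, NUM_CLUSTERS)
    optimizer.add_param_group({"params": classifier.parameters()})

    # Step 5: one training step on the pseudo-labels.
    optimizer.zero_grad()
    loss = F.cross_entropy(classifier(backbone(images)), pseudo_labels)
    loss.backward()
    optimizer.step()

    # Step 0 of the next round: drop the classifier's param group.
    del optimizer.param_groups[1]

print(len(optimizer.param_groups))  # 1
print(len(optimizer.state))         # 6: backbone (2) + two discarded classifiers (2 each)
```

After two rounds, optimizer.state holds more entries than there are parameters left in param_groups.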

The problem is that when saving the model and the optimizer state, I call optimizer.state_dict() to get the optimizer state, and the following error is raised:

/root/module/class.py in save_model_parameters(self, model_parameters_path, epoch, optimizer)
    254                       'state_dict': self.state_dict()}
    255         if optimizer:
--> 256             model_dict['optimizer'] = optimizer.state_dict()
    257 
    258         torch.save(model_dict, model_parameters_path)

/usr/local/lib/python3.6/dist-packages/torch/optim/optimizer.py in state_dict(self)
     96         # Remap state to use order indices as keys
     97         packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
---> 98                         for k, v in self.state.items()}
     99         return {
    100             'state': packed_state,

/usr/local/lib/python3.6/dist-packages/torch/optim/optimizer.py in <dictcomp>(.0)
     96         # Remap state to use order indices as keys
     97         packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
---> 98                         for k, v in self.state.items()}
     99         return {
    100             'state': packed_state,

KeyError: 140413627951216
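A minimal reproduction sketch, assuming Adam (which keeps per-parameter state) and two small nn.Linear layers standing in for the backbone and the temporary classifier:

```python
import torch
import torch.nn as nn

backbone = nn.Linear(4, 3)    # stand-in for the feature extractor
classifier = nn.Linear(3, 2)  # stand-in for the temporary classifier
optimizer = torch.optim.Adam(backbone.parameters(), lr=1e-3)
optimizer.add_param_group({"params": classifier.parameters()})

# One step so Adam creates state entries for every parameter with a gradient.
loss = classifier(backbone(torch.randn(5, 4))).sum()
loss.backward()
optimizer.step()

del optimizer.param_groups[1]  # the group is gone, its state entries are not

try:
    optimizer.state_dict()
    error = None
except KeyError as exc:
    error = exc
print(type(error).__name__)  # KeyError
```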

It seems that the problem lies in the following part of the optimizer.state_dict() method:
param_mappings[id(k)]

The id of some parameter in optimizer.state doesn't exist in param_mappings.
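To see which state entries have no matching parameter, a small diagnostic can rebuild the set of ids that state_dict() collects from param_groups and compare it with the keys of optimizer.state. The helper name orphaned_state_params is hypothetical, and the setup assumes Adam plus the del optimizer.param_groups[1] pattern:

```python
import torch
import torch.nn as nn

def orphaned_state_params(optimizer):
    """Return tensors that have optimizer state but belong to no param group.

    These are the keys for which the param_mappings[id(k)] lookup in
    state_dict() fails.
    """
    live_ids = {id(p) for group in optimizer.param_groups for p in group["params"]}
    return [k for k in optimizer.state
            if isinstance(k, torch.Tensor) and id(k) not in live_ids]

backbone, classifier = nn.Linear(4, 3), nn.Linear(3, 2)
optimizer = torch.optim.Adam(backbone.parameters(), lr=1e-3)
optimizer.add_param_group({"params": classifier.parameters()})
classifier(backbone(torch.randn(5, 4))).sum().backward()
optimizer.step()
del optimizer.param_groups[1]

orphans = orphaned_state_params(optimizer)
print(len(orphans))  # 2: the classifier's weight and bias
```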
Strangely, the error is raised in some environments (Colab, cloud) but not in others (local).
Do you have any thoughts?

I believe the problem lies in the way I remove param groups:
del optimizer.param_groups[1]
This removes the second param group from the optimizer, but the state of its parameters seems to remain in optimizer.state, which leads to an error when optimizer.state is matched against param_mappings.
If that is the case, what is the right way to remove param groups? (The PyTorch optimizer documentation doesn't cover this.)
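One possible workaround, assuming the removed classifier's optimizer state can simply be discarded, is to pop the group and then drop the state entries of its parameters. remove_param_group is a hypothetical helper, not part of the torch.optim API:

```python
import torch
import torch.nn as nn

def remove_param_group(optimizer, index):
    """Remove param_groups[index] together with the state of its parameters."""
    group = optimizer.param_groups.pop(index)
    for p in group["params"]:
        # optimizer.state is keyed by the parameter tensors themselves.
        optimizer.state.pop(p, None)
    return group

backbone, classifier = nn.Linear(4, 3), nn.Linear(3, 2)
optimizer = torch.optim.Adam(backbone.parameters(), lr=1e-3)
optimizer.add_param_group({"params": classifier.parameters()})
classifier(backbone(torch.randn(5, 4))).sum().backward()
optimizer.step()

remove_param_group(optimizer, 1)
checkpoint = optimizer.state_dict()  # no KeyError now
print(sorted(checkpoint["state"]))  # [0, 1]: the backbone's weight and bias
```

The backbone's Adam moments are left untouched. Note that optimizer.load_state_dict() checks that the number of param groups and their sizes match, so a checkpoint saved this way should be loaded into an optimizer that also has only the backbone group.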