Parameter missing from state_dict

When I print all the parameter IDs listed in optimizer_state['param_groups'], one of them (index 136) appears to be missing from optimizer_state['state'].
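For context, PyTorch optimizers such as AdamW create a parameter's state entry lazily, the first time step() sees a gradient for that parameter. A parameter that never receives a gradient therefore keeps its ID in param_groups but has no key in state. The following is a minimal sketch, not a confirmed diagnosis of index 136. It uses a hypothetical toy model (ToyModel, with one layer deliberately left out of forward()) and a hypothetical helper, missing_state_ids, to list the IDs that have no state entry:

```python
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical model: `unused` is never called, so its params get no grad."""

    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 2)
        self.unused = nn.Linear(4, 2)

    def forward(self, x):
        return self.used(x)


def missing_state_ids(state_dict):
    """Hypothetical helper: IDs listed in param_groups that have no entry in state."""
    group_ids = [i for group in state_dict["param_groups"] for i in group["params"]]
    return [i for i in group_ids if i not in state_dict["state"]]


model = ToyModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model(torch.randn(3, 4)).sum().backward()
optimizer.step()  # creates state only for params whose .grad is not None

sd = optimizer.state_dict()
print(missing_state_ids(sd))  # [2, 3] (unused.weight, unused.bias)
```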

I am using a modified version of the HuggingFace code to load a model from a checkpoint on TPU, and I run into the following error:

Traceback (most recent call last):
  File "transformers/examples/", line 85, in
  File "transformers/examples/", line 81, in main
    xmp.spawn(mod._mp_fn, args=(), nprocs=args.num_cores)
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/torch_xla/distributed/", line 292, in spawn
    _start_fn(0, pf_cfg, fn, args)
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/torch_xla/distributed/", line 229, in _start_fn
    fn(gindex, *args)
  File "/home/asd/source_code/Multilingual/transformers/examples/language-modeling/", line 486, in _mp_fn
  File "/home/asd/source_code/Multilingual/transformers/examples/language-modeling/", line 460, in main
  File "/home/asd/source_code/Multilingual/transformers/src/transformers/", line 666, in train
  File "/home/asd/source_code/Multilingual/transformers/src/transformers/", line 1003, in _load_optimizer_and_scheduler
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/torch/optim/", line 123, in load_state_dict
    raise ValueError("loaded state dict contains a parameter group "
ValueError: loaded state dict contains a parameter group that doesn't match the size of optimizer's group

After some poking around, I've attributed the error to the missing index in optimizer_state['state'] (see the attached image). Is there any reason why this could be happening?
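For reference, in PyTorch's Optimizer.load_state_dict this particular ValueError is raised when a saved group's "params" list has a different length from the corresponding group in the current optimizer. Entries in "state" are not part of that length check. Below is a minimal sketch for printing the mismatched groups before calling load_state_dict. The helper compare_group_sizes and the two toy models (old, new) are hypothetical and only reproduce the error on CPU:

```python
import torch
import torch.nn as nn


def compare_group_sizes(optimizer, saved_state_dict):
    """Hypothetical helper: (group index, current len, saved len) for each mismatched group."""
    current = optimizer.param_groups
    saved = saved_state_dict["param_groups"]
    if len(current) != len(saved):
        raise ValueError(f"{len(current)} current groups vs {len(saved)} saved groups")
    return [
        (i, len(c["params"]), len(s["params"]))
        for i, (c, s) in enumerate(zip(current, saved))
        if len(c["params"]) != len(s["params"])
    ]


# Hypothetical reproduction: the saved optimizer covered 2 layers, the new one covers 3.
old = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
new = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 2))
saved = torch.optim.AdamW(old.parameters(), lr=1e-3).state_dict()
opt = torch.optim.AdamW(new.parameters(), lr=1e-3)

print(compare_group_sizes(opt, saved))  # [(0, 6, 4)]
```

With the HuggingFace Trainer, the optimizer is usually built with more than one group (for example, with and without weight decay), so a mismatch can show up in any group index, not only group 0.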

I'm using, os.path.join(output_dir, "")) to save the model, and self.optimizer.load_state_dict(optimizer_state) to load the optimizer state back.
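Because the save call above is only partially shown, the following sketch rests on assumptions: it uses torch.save and torch.load on CPU rather than the TPU code path, and build() and the file name optimizer_state.pt are hypothetical. It shows a save/load round trip of the optimizer state that succeeds because the optimizer is rebuilt with the same parameters and the same group split before load_state_dict is called:

```python
import os
import tempfile

import torch
import torch.nn as nn


def build(seed=0):
    # Hypothetical helper: build model and optimizer identically at save and load time
    # (same parameters, same two-group split by weight decay).
    torch.manual_seed(seed)
    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
    decay = [p for n, p in model.named_parameters() if not n.endswith("bias")]
    no_decay = [p for n, p in model.named_parameters() if n.endswith("bias")]
    optimizer = torch.optim.AdamW(
        [{"params": decay, "weight_decay": 0.01},
         {"params": no_decay, "weight_decay": 0.0}],
        lr=1e-3,
    )
    return model, optimizer


model, optimizer = build()
model(torch.randn(3, 4)).sum().backward()
optimizer.step()  # every parameter gets a gradient, so every one gets a state entry

with tempfile.TemporaryDirectory() as output_dir:
    path = os.path.join(output_dir, "optimizer_state.pt")  # hypothetical file name
    torch.save(optimizer.state_dict(), path)

    _, restored = build()
    optimizer_state = torch.load(path)
    restored.load_state_dict(optimizer_state)  # group sizes match, so no ValueError

print(len(restored.state))  # 4
```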