Parameter missing from state_dict

TLDR:
When I print all the IDs present in optimizer_state['param_groups'], one of them (index 136) appears to be missing from optimizer_state['state'].
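For reference, the comparison described above can be sketched in plain Python. This is only a minimal sketch: it assumes the usual layout of an optimizer state dict (a 'state' dict keyed by parameter ID and a 'param_groups' list whose entries each carry a 'params' list of IDs). The helper name find_ids_without_state and the toy dict are hypothetical stand-ins, not part of the actual training code.

```python
def find_ids_without_state(optimizer_state):
    """Return parameter IDs listed in param_groups that have no entry in state.

    Hypothetical helper: assumes the standard state-dict layout with
    'state' (dict keyed by ID) and 'param_groups' (list of dicts with 'params').
    """
    group_ids = [
        pid
        for group in optimizer_state["param_groups"]
        for pid in group["params"]
    ]
    state_ids = set(optimizer_state["state"].keys())
    return sorted(pid for pid in group_ids if pid not in state_ids)


# Toy stand-in for a loaded optimizer.pt (illustrative values only).
toy_state = {
    "state": {0: {"step": 1}, 2: {"step": 1}},
    "param_groups": [
        {"lr": 1e-3, "params": [0, 1]},
        {"lr": 1e-3, "params": [2]},
    ],
}

missing_ids = find_ids_without_state(toy_state)
print(missing_ids)  # [1]
```

Run against the real optimizer_state, this would be expected to report the single missing ID (136 in my case).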

I am using a modified version of the HuggingFace code to load a model from a checkpoint on TPU, and I run into the following error:

Traceback (most recent call last):
  File "transformers/examples/xla_spawn.py", line 85, in <module>
    main()
  File "transformers/examples/xla_spawn.py", line 81, in main
    xmp.spawn(mod._mp_fn, args=(), nprocs=args.num_cores)
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 292, in spawn
    _start_fn(0, pf_cfg, fn, args)
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 229, in _start_fn
    fn(gindex, *args)
  File "/home/asd/source_code/Multilingual/transformers/examples/language-modeling/run_mlm_synthetic.py", line 486, in _mp_fn
    main()
  File "/home/asd/source_code/Multilingual/transformers/examples/language-modeling/run_mlm_synthetic.py", line 460, in main
    trainer.train(model_path=model_path)
  File "/home/asd/source_code/Multilingual/transformers/src/transformers/trainer_word_modifications.py", line 666, in train
    self._load_optimizer_and_scheduler(model_path)
  File "/home/asd/source_code/Multilingual/transformers/src/transformers/trainer_word_modifications.py", line 1003, in _load_optimizer_and_scheduler
    self.optimizer.load_state_dict(optimizer_state)
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/torch/optim/optimizer.py", line 123, in load_state_dict
    raise ValueError("loaded state dict contains a parameter group "
ValueError: loaded state dict contains a parameter group that doesn't match the size of optimizer's group

After some poking around, I’ve attributed the error to the missing index in optimizer_state['state'] (see the attached image). Is there any reason why this could be happening?
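A further check that might help narrow this down, sketched under an assumption: as far as I can tell, the ValueError in the traceback is raised when a saved parameter group and the corresponding group in the live optimizer contain a different number of parameters. The helper compare_group_sizes and the toy dicts below are hypothetical; in practice the two arguments would be the loaded optimizer_state and self.optimizer.state_dict().

```python
def compare_group_sizes(saved_state, current_state):
    """Compare per-group parameter counts of two optimizer state dicts.

    Hypothetical helper: returns (number of saved groups, number of current
    groups, list of (group_index, saved_len, current_len) for mismatches).
    """
    saved = [len(g["params"]) for g in saved_state["param_groups"]]
    current = [len(g["params"]) for g in current_state["param_groups"]]
    mismatches = [
        (i, s, c)
        for i, (s, c) in enumerate(zip(saved, current))
        if s != c
    ]
    return len(saved), len(current), mismatches


# Toy stand-ins (illustrative values only): the saved checkpoint has one
# parameter fewer in group 0 than the freshly built optimizer.
saved_toy = {"state": {}, "param_groups": [{"params": [0, 1]}, {"params": [2]}]}
current_toy = {"state": {}, "param_groups": [{"params": [0, 1, 2]}, {"params": [3]}]}

n_saved, n_current, mismatches = compare_group_sizes(saved_toy, current_toy)
print(n_saved, n_current, mismatches)  # 2 2 [(0, 2, 3)]
```

If the mismatch list comes back empty for the real state dicts, the group sizes agree and the cause would have to lie elsewhere.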

I’m using xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) to save the optimizer state, and self.optimizer.load_state_dict(optimizer_state) to load it.