Key Truncation Issue in Checkpoint Save/Load

Hi everyone,

I am currently facing an issue when trying to load a checkpoint after training. The problem arises from a key mismatch in the state_dict, which produces the following error:

RuntimeError: Error(s) in loading state_dict for MyLightningModule:
Missing key(s) in state_dict: "model.llm.embeddings.word_embeddings.weight",
Unexpected key(s) in state_dict: "ings.word_embeddings.weight",

It seems that during the checkpoint save process the beginning of each key is being cut off: model.llm.embeddings.word_embeddings.weight is saved as ings.word_embeddings.weight, with everything before "ings" missing.

Could someone help me understand why this key truncation is happening during the save operation? I would greatly appreciate any insights or guidance you could provide. Thank you very much!

I haven’t seen an issue like this before, where parts of the keys are missing.
Is this the only key that shows the mismatch?

Thank you so much for your reply! The mismatch affects all keys. For instance, “model.pre_classifier.weight” and “model.pre_classifier.bias” were truncated to “fier.weight” and “fier.bias,” respectively. It seems that the first 16 characters were consistently removed from every key in the saved checkpoint.
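The 16-character pattern can be verified with a quick sketch (key names taken from the errors above):

```python
# Slicing off the first 16 characters of the expected keys reproduces
# exactly the truncated keys seen in the saved checkpoint.
keys = [
    "model.llm.embeddings.word_embeddings.weight",
    "model.pre_classifier.weight",
    "model.pre_classifier.bias",
]
print([k[16:] for k in keys])
# → ['ings.word_embeddings.weight', 'fier.weight', 'fier.bias']
```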

This is quite strange, and I greatly appreciate your insights into why this issue is occurring. Thanks again!

I have identified that the issue is caused by the Ray package, and I will reach out to them for a fix. Thank you for your help!

That’s interesting! If you have a chance, I would be really interested in the root cause as it’s indeed a weird but interesting issue.

It turns out that RayFSDPStrategy from Ray truncates keys in state_dict using the following code:

state_dict = self.model.state_dict()
prefix_len = len("_forward_module.")  # 16 characters
# The slice is applied unconditionally, so keys that do not start with
# "_forward_module." also lose their first 16 characters.
return {k[prefix_len:]: v for k, v in state_dict.items()}
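A safer variant (a sketch for illustration, not Ray's actual fix) would strip the prefix only when a key actually starts with it, so keys without the prefix pass through unchanged:

```python
# Hypothetical prefix-aware version of the truncating dict comprehension.
def strip_forward_module_prefix(state_dict, prefix="_forward_module."):
    # Remove the prefix only where it is present; leave other keys intact.
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

example = {
    "_forward_module.model.pre_classifier.weight": 1,  # prefixed key
    "model.pre_classifier.bias": 2,  # unprefixed key: must survive as-is
}
print(strip_forward_module_prefix(example))
# → {'model.pre_classifier.weight': 1, 'model.pre_classifier.bias': 2}
```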