PyTorch load_state_dict not working properly

I am trying to load a checkpoint into a PyTorch module. But after I call load_state_dict, the module's parameters do not match the checkpoint file, and all the outputs are noise.

I added an assertion to check that the loaded parameters are equal to the ones in the checkpoint, and it was triggered.
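When the loaded parameters don't match the checkpoint, one thing worth checking first is whether the key names line up. With the default strict=True, load_state_dict raises an error on mismatched keys; with strict=False it skips them and returns the missing and unexpected keys instead, leaving those parameters at their initial values. A minimal sketch of that comparison, assuming you pass it the keys of model.state_dict() and of the loaded checkpoint dict (compare_keys is a hypothetical helper, not a PyTorch API):

```python
def compare_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected), similar to what load_state_dict(strict=False) reports."""
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    # Keys the model expects but the checkpoint does not provide.
    missing = sorted(model_keys - checkpoint_keys)
    # Keys in the checkpoint that the model does not recognize.
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Usage with PyTorch (illustrative, not executed here):
#   missing, unexpected = compare_keys(model.state_dict().keys(), checkpoint.keys())
missing, unexpected = compare_keys(
    ["fc.weight", "fc.bias"],
    ["_orig_mod.fc.weight", "_orig_mod.fc.bias"],
)
# missing    -> ['fc.bias', 'fc.weight']
# unexpected -> ['_orig_mod.fc.bias', '_orig_mod.fc.weight']
```

If every key shows up in both lists with a consistent extra prefix, the weights were most likely saved from a wrapped model rather than being corrupt.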


I am using PyTorch 2.4.1. Does anyone have an idea?

Was your model compiled with torch.compile before you saved its state_dict? If so, you need to compile the model first and then load the state dictionary into it.
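Why compiling first helps: torch.compile returns a wrapper module that keeps the original model as its _orig_mod attribute, so every key in a compiled model's state_dict starts with "_orig_mod." and does not match the keys of an uncompiled model. An alternative to compiling before loading is to remove that prefix from the checkpoint (or to save compiled_model._orig_mod.state_dict() in the first place). A minimal sketch, assuming the checkpoint is a plain state dict saved from a compiled model; strip_prefix and the file name "checkpoint.pt" are hypothetical:

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="_orig_mod."):
    """Return a copy of state_dict with `prefix` removed from every key that starts with it."""
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# Usage with PyTorch (illustrative, not executed here; `model` is the uncompiled module):
#   checkpoint = torch.load("checkpoint.pt", map_location="cpu")
#   model.load_state_dict(strip_prefix(checkpoint))  # strict=True still raises if keys differ
```

Keeping strict=True (the default) when loading means any remaining mismatch raises an error instead of silently leaving parameters at their random initial values.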