PyTorch load_state_dict does not work properly

I am trying to load a checkpoint into a PyTorch module, but after I call load_state_dict, the module's parameters do not match the checkpoint file and all the outputs are noise.

I added an assertion to check that the loaded parameters are equal to the ones in the checkpoint, and it was triggered.
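The original assertion code is not shown, so here is a minimal sketch of such a check, assuming a small stand-in model (nn.Sequential of two Linear layers) instead of the real one. The helper name check_loaded is hypothetical. It compares every checkpoint entry with the model's state_dict after loading, moving both tensors to CPU so the comparison works even when they live on different devices.

```python
import io

import torch
import torch.nn as nn


def check_loaded(model: nn.Module, checkpoint: dict) -> list:
    """Return the names of checkpoint entries that differ from the model's current values."""
    current = model.state_dict()
    mismatched = []
    for name, saved in checkpoint.items():
        # Compare on CPU so tensors on different devices can still be compared.
        if name not in current or not torch.equal(current[name].cpu(), saved.cpu()):
            mismatched.append(name)
    return mismatched


# Hypothetical stand-in model; the model from the question is not shown.
src = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
buffer = io.BytesIO()
torch.save(src.state_dict(), buffer)
buffer.seek(0)
checkpoint = torch.load(buffer, weights_only=True)

dst = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
result = dst.load_state_dict(checkpoint)  # strict=True by default
# Any key mismatch would show up here (strict=True raises on mismatch instead).
print(result.missing_keys, result.unexpected_keys)
mismatched = check_loaded(dst, checkpoint)
assert not mismatched, f"parameters differ from checkpoint: {mismatched}"
```

Printing missing_keys and unexpected_keys is worth doing whenever strict=False is used, because in that mode a key mismatch does not raise and the affected parameters silently keep their random initialization.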


I am using PyTorch 2.4.1. Does anyone have an idea what is going on?

Was your model compiled with torch.compile before you saved its state_dict? If so, compile the new model first and then load the state dictionary into it.
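The reason this matters: the module returned by torch.compile wraps the original one as a submodule named _orig_mod, so its state_dict keys carry an "_orig_mod." prefix. Loading those keys into an uncompiled model raises a key-mismatch error with strict=True, and with strict=False it loads nothing and leaves the random initialization in place. Besides compiling the target model first, a sketch of the other workaround is to strip the prefix before loading. The helper strip_compile_prefix is hypothetical, and the code only simulates a compiled checkpoint by adding the prefix by hand rather than calling torch.compile.

```python
import torch.nn as nn

COMPILE_PREFIX = "_orig_mod."


def strip_compile_prefix(state_dict: dict, prefix: str = COMPILE_PREFIX) -> dict:
    """Return a copy of state_dict with `prefix` removed from every key that has it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


model = nn.Linear(4, 2)
# Simulated checkpoint from a compiled model: same tensors, prefixed keys.
saved = {COMPILE_PREFIX + k: v.clone() for k, v in model.state_dict().items()}

fresh = nn.Linear(4, 2)
# Loading the prefixed keys with strict=False silently loads nothing.
silent = fresh.load_state_dict(saved, strict=False)
# After stripping the prefix the keys line up and strict loading succeeds.
result = fresh.load_state_dict(strip_compile_prefix(saved))
```

Compiling the target model before loading, as suggested above, avoids the renaming entirely because both sides then use the prefixed keys.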

I have solved the problem. The parameters of my model were on different devices (some on the CPU and some on CUDA). Calling model.cuda() before load_state_dict resolved the problem. I have no idea why this happened.
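A minimal sketch of how one might detect this situation and apply the same fix, assuming a stand-in model (the real one is not shown); the helper parameter_devices is hypothetical. It collects the device of every parameter and buffer, moves the whole model to a single device, and only then calls load_state_dict. It falls back to the CPU when CUDA is unavailable, and it does not explain why the mixed-device state produced wrong values in the original case.

```python
import torch
import torch.nn as nn


def parameter_devices(model: nn.Module) -> set:
    """Collect every device that holds a parameter or buffer of the model."""
    tensors = list(model.parameters()) + list(model.buffers())
    return {t.device for t in tensors}


# Hypothetical stand-in model and checkpoint (checkpoint tensors stay on the CPU).
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))
checkpoint = {k: v.clone() for k, v in model.state_dict().items()}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("devices before:", parameter_devices(model))
model.to(device)  # move every parameter and buffer to one device first
assert len(parameter_devices(model)) == 1, parameter_devices(model)
# load_state_dict copies the values into the existing tensors, which stay on `device`.
model.load_state_dict(checkpoint)
```

Checking parameter_devices(model) before loading is a cheap way to spot a model that is split across devices.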