Loading state_dict after code refactoring gives import error

Using the standard code to save a model:

checkpoint = {
    'model': model,
    'state_dict': model.state_dict(),
    'optimizer': optimizer,
    'scheduler': scheduler,
}
torch.save(checkpoint, path_out)

Since refactoring my code, the architecture class lives in a module with a different name, so I get an error when trying to load the model: ModuleNotFoundError: No module named '.....', which makes sense, since it isn't there :wink:

My current workaround is to create a fake module under the old module name containing an empty class, and once the checkpoint is loaded, I just load the weights: model.load_state_dict(state_dict=state_dict)
But this solution is inelegant, to say the least.
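
To illustrate the workaround described above, here is a minimal sketch. It uses the standard pickle module in place of torch.save / torch.load (torch.save pickles Python objects, and pickle records a class by its module path, which is where the ModuleNotFoundError comes from). The module name old_arch, the class name Net and the helper install_stub are hypothetical placeholders, not names from my actual project.

```python
import pickle
import sys
import types

OLD_MODULE = "old_arch"  # hypothetical pre-refactor module name


def install_stub(module_name, class_name):
    """Register a fake module holding an empty class under the given name."""
    stub = types.ModuleType(module_name)
    empty_class = type(class_name, (), {"__module__": module_name})
    setattr(stub, class_name, empty_class)
    sys.modules[module_name] = stub
    return stub


# Before the refactor: the class is importable as old_arch.Net.
original = install_stub(OLD_MODULE, "Net")
blob = pickle.dumps({
    "model": original.Net(),                   # pickled by module path
    "state_dict": {"fc.weight": [0.1, 0.2]},   # plain data stand-in
})

# After the refactor: the old module is gone, so unpickling fails.
del sys.modules[OLD_MODULE]
try:
    pickle.loads(blob)
    load_failed = False
except ModuleNotFoundError:
    load_failed = True

# Workaround: fake module with an empty class, then keep only the weights.
install_stub(OLD_MODULE, "Net")
checkpoint = pickle.loads(blob)
state_dict = checkpoint["state_dict"]
del sys.modules[OLD_MODULE]  # remove the stub again once loading is done
```

With a real checkpoint, the extracted state_dict would then be passed to model.load_state_dict on the refactored class.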

Using torch.load with weights_only returns a similar error.

Is there a way to “force” load the file? Or perhaps to load only the state_dict without trying to load the “model” part?

You are not using the standard and recommended approach, since you are storing the model object itself alongside its state_dict. Storing the model adds the requirement that the original source file structure be restored at load time, so remove 'model': model and store only the state_dict to avoid such issues.
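
A minimal sketch of what this could look like, assuming a small nn.Linear as a stand-in for the real architecture and a temporary file for path_out. Saving the optimizer and scheduler as state_dicts as well is an extra assumption on top of the advice above, made so that the file contains only tensors and plain Python values:

```python
import os
import tempfile

import torch
from torch import nn

# Stand-ins for the real training objects (assumptions for this sketch).
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Store only state_dicts, so no class or module path is pickled.
checkpoint = {
    "state_dict": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
}
path_out = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save(checkpoint, path_out)

# Later, possibly after refactoring: rebuild the model from code,
# then load the weights into it.
restored = nn.Linear(4, 2)
loaded = torch.load(path_out, map_location="cpu", weights_only=True)
restored.load_state_dict(loaded["state_dict"])
```

Since the checkpoint no longer references the model class, moving or renaming the module that defines the architecture should not affect loading, as long as the parameter names and shapes still match.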
