Convert PyTorch checkpoint to state_dict only

I have a model checkpoint saved with torch.save(model, 'model.pt'), i.e. the whole pickled model object rather than just its weights, so loading it requires the model's class definition. How can I load the state dict only, without requiring any of the source code for the model definition?

I want something like torch.load('model.pt', state_dict_only=True). I know that flag doesn't exist, but is there a similar hack?

Basically I have a checkpoint but none of the original code used to save the model. So I want to work with the state dict only.

Does loading the checkpoint and then calling .state_dict() work for you?
Like so:

model = torch.load('model.pt')
state_dict = model.state_dict()

No, because I don't have the original code. So the line

model_checkpoint = torch.load('model.pt')

breaks

So in what shape is the checkpoint saved? If it's a dictionary, you can access the weights with checkpoint['state_dict'].

From my experience, the state_dict alone is not usable without the model implementation class.
Can you either implement the model class, or find a repository with the implementation, and then try the following:

loaded_state_dict = torch.load('model.pt')['state_dict']
model.load_state_dict(loaded_state_dict)

Reply in case you run into state_dict key errors, etc…
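In case it helps with key errors: load_state_dict(strict=False) reports mismatched keys instead of raising, and keys can also be renamed by hand to fit a reimplementation. A minimal sketch (Reimplemented and the fc./classifier. attribute names are made up):

```python
import torch
import torch.nn as nn

# Hypothetical reimplementation whose attribute names differ slightly
# from the ones stored in the checkpoint.
class Reimplemented(nn.Module):
    def __init__(self):
        super().__init__()
        self.classifier = nn.Linear(8, 2)

# Pretend this came out of the checkpoint.
checkpoint_sd = {"fc.weight": torch.zeros(2, 8), "fc.bias": torch.zeros(2)}

model = Reimplemented()
# strict=False loads whatever matches and reports the rest instead of raising.
result = model.load_state_dict(checkpoint_sd, strict=False)
print(result.missing_keys)     # ['classifier.weight', 'classifier.bias']
print(result.unexpected_keys)  # ['fc.weight', 'fc.bias']

# Keys can also be renamed so the checkpoint fits the new module exactly.
renamed = {k.replace("fc.", "classifier."): v for k, v in checkpoint_sd.items()}
model.load_state_dict(renamed)  # all keys match now
```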

Hi @Peter_Featherstone,

@toufiq is correct that you can only load a model if you have access to its original source code. All the state_dict object does is save the weights and buffers of an nn.Module to a dict, and when you load the state_dict object, it just matches the keys of the dict to your existing nn.Module.
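To illustrate that key matching (TinyNet is a made-up module for demonstration):

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):  # made-up module for illustration
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 1)

sd = TinyNet().state_dict()
print(list(sd))  # ['fc.weight', 'fc.bias'] -- keys are attribute paths

# Any nn.Module instance with the same key structure (and tensor shapes)
# can absorb these weights; no original source file is consulted.
fresh = TinyNet()
fresh.load_state_dict(sd)
```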

Thank you all. Yes, I understand the subtleties. Basically I have a PyTorch model file I got from somewhere, but I don't have the code. I can reverse engineer the model, but it won't have exactly the same code. If I just had the state dict, I could figure out how to load the weights. But I need PyTorch to load the state dict without complaining that it can't find the code. I'm tempted to hack the .pt file by unzipping it and manually extracting the state dict.
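One possible hack along those lines, without unzipping by hand: torch.load accepts a pickle_module, so you can supply an unpickler that substitutes a stub nn.Module for any class it can't import; nn.Module.__setstate__ still restores the weights onto the stub, and .state_dict() then works. This is a sketch, not guaranteed against every checkpoint (it needs a reasonably recent PyTorch for the weights_only argument, and SecretModel, _Stub, and lenient_pickle are illustrative names; the first block only simulates your situation):

```python
import os
import pickle
import tempfile
import types

import torch
import torch.nn as nn
import __main__

# --- Simulate the problem: a full-model checkpoint whose class is gone ---
class SecretModel(nn.Module):  # stands in for the unknown original class
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

SecretModel.__module__ = "__main__"
__main__.SecretModel = SecretModel       # class is findable while saving
path = os.path.join(tempfile.mkdtemp(), "model.pt")
torch.save(SecretModel(), path)          # torch.save(model, ...), as in the question
del __main__.SecretModel                 # plain torch.load(path) now fails

# --- The hack: hand torch.load an unpickler that stubs missing classes ---
class _Stub(nn.Module):
    """Placeholder for any class whose source code is unavailable."""

class _LenientUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        try:
            return super().find_class(module, name)
        except (ImportError, AttributeError):
            return _Stub  # nn.Module.__setstate__ still restores the weights

lenient = types.ModuleType("lenient_pickle")  # torch.load wants a module-like object
lenient.Unpickler = _LenientUnpickler

model = torch.load(path, pickle_module=lenient, weights_only=False)
state_dict = model.state_dict()
print(sorted(state_dict))      # ['fc.bias', 'fc.weight']
torch.save(state_dict, path)   # re-save as a plain state_dict
```

After the re-save, the file contains only tensors keyed by attribute path, so it loads anywhere without the original class.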

Can you explain what you mean by this? If you have an nn.Module with the same key structure, PyTorch shouldn’t complain, I think.