Saving a PyTorch model: still not understood

I really need help! I have read everything I could find on the internet, e.g.:


https://pytorch.org/docs/master/notes/serialization.html

Either it's too complicated for beginners like me, or what I found didn't work.

As I understand it, there are two main approaches:

  1. This approach is recommended, but to me it's not convenient, as I first need to create an instance of the model before loading the weights (and when using the model in another notebook I don't necessarily want to load all that information again).

torch.save(model.state_dict(), os.path.join(path_, 'model_' + name_ + '.pt'))
# then to load it again:
model2 = Net()
model2.load_state_dict(torch.load(os.path.join(path_, 'model_' + name_ + '.pt')))
model2.eval()  # set evaluation mode; needed for dropout and batchnorm

  2. This one is not recommended, but I don't understand why. Here I again need the class definition. But then why wouldn't I just save the model's state_dict using torch.save(model.state_dict(), …)? Also, I read that in this situation I should not put the model into eval mode. But that only matters if I want to train it directly afterwards, no? If I want to do more evaluation or test it on new data, I can put the model in eval mode, right? And then back into train mode again for training, I guess.

state = {'model': model, 'state_dict': model.state_dict()}
torch.save(state, os.path.join(path_, 'model_' + name_ + '.tar'))
# and to load it again:
state = torch.load(os.path.join(path, 'model_' + name_ + '.tar'))
model2 = state['model']
model2.load_state_dict(state['state_dict'])
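On the eval/train question: switching modes is not destructive, and you can toggle back and forth as often as you like. A minimal sketch (the model here is a made-up example with layers that behave differently in the two modes):

```python
import torch.nn as nn

# hypothetical tiny model with layers affected by train/eval mode
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4), nn.Dropout(0.5))

model.eval()           # dropout disabled, batchnorm uses running stats
print(model.training)  # False

model.train()          # back to training behaviour; nothing was lost
print(model.training)  # True
```

So eval() before testing and train() before resuming training is exactly the intended workflow.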

I simply need to save the model in order to use it without needing to load the whole class definition again. Also, if I could understand how to save it in a way that lets me train it again, that would be very useful too.
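For the "train it again later" part, the usual pattern (a sketch, not the only way; the filename and variable names here are my own) is to save a checkpoint dict containing both the model's state_dict and the optimizer's state_dict, then restore both before resuming:

```python
import torch
import torch.nn as nn
import torch.optim as optim

# hypothetical tiny model and optimizer, just for illustration
model = nn.Linear(3, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)

# save everything needed to resume training
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': 5,
}
torch.save(checkpoint, 'checkpoint.pt')

# later (you still need the class definition to build the model):
model2 = nn.Linear(3, 1)
optimizer2 = optim.SGD(model2.parameters(), lr=0.01)
ckpt = torch.load('checkpoint.pt')
model2.load_state_dict(ckpt['model_state_dict'])
optimizer2.load_state_dict(ckpt['optimizer_state_dict'])
model2.train()  # back in training mode to resume
```

Restoring the optimizer state matters for optimizers with internal buffers (momentum, Adam's moment estimates); without it, resumed training starts from a cold optimizer.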

Hi,

Before my answer, could you tell me where checkpoint came from in the snippet below?


All the answers below are derived from my limited understanding.

As to your questions 2 and 3:

When you save a model with torch.save, PyTorch saves the state_dict along with a reference to the model's class (found via inspect), not the code itself.
This means

the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects, or after some serious refactors.

as written in https://pytorch.org/docs/master/notes/serialization.html.

So, in the long run, it's preferable to save the model class and its state_dict separately, even if you don't like the process.
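To illustrate in plain Python why saving the whole model is fragile (torch.save uses pickle under the hood): pickle stores the class's import path, not its source code, so loading breaks if the class is renamed or moved. A minimal sketch with the standard pickle module and a stand-in Net class:

```python
import pickle

class Net:
    """Stand-in for a model class."""
    def __init__(self):
        self.weight = 1.0

net = Net()
data = pickle.dumps(net)

# The serialized bytes contain the class NAME, not the class's code:
assert b'Net' in data

# Unpickling re-imports the class by that recorded path. If Net is
# renamed, moved to another file, or the directory structure changes,
# pickle.loads(data) raises ModuleNotFoundError / AttributeError.
restored = pickle.loads(data)
print(restored.weight)  # 1.0
```

Saving only the state_dict avoids this: it is just a dict of tensors, with no dependency on where the class lives.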

Thank you @crcrpar for your answer!
You are right, it's a typo on my side; I will edit it now, sorry for this.
I had already read this:
‘the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects, or after some serious refactors.’

but did not understand it at all… If you have simpler words to explain it, I would be very happy to hear them.