I really need help! I have read everything I could find on the internet, e.g.:
https://pytorch.org/docs/master/notes/serialization.html
Either it's too complicated for beginners like me, or what I found doesn't work.
As I understand it, there are two main approaches:
- This approach is recommended, but I find it inconvenient, because I first need to create an instance of the model before loading it (and when using the model in another notebook, I don't necessarily want to bring in the whole class definition again).
import os
import torch
torch.save(model.state_dict(), os.path.join(path_, 'model_' + name_ + '.pt'))
# then to load it again:
model2 = Net()
model2.load_state_dict(torch.load(os.path.join(path_, 'model_' + name_ + '.pt')))
model2.eval()  # evaluation mode is needed for dropout and batchnorm layers
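A minimal, self-contained sketch of this first approach, assuming a small hypothetical `Net` in place of the real one and a temporary directory for `path_`. Creating the instance is cheap: it only allocates randomly initialised weights, which `load_state_dict` then overwrites, so nothing is retrained. What this approach does require is that the class definition is available wherever you load.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Net(nn.Module):
    # Hypothetical stand-in for your own Net class
    def __init__(self):
        super().__init__()
        self.drop = nn.Dropout(p=0.5)
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(self.drop(x))


model = Net()
path_ = tempfile.mkdtemp()  # hypothetical output directory
name_ = "demo"              # hypothetical model name
weights_file = os.path.join(path_, "model_" + name_ + ".pt")

# Save only the learned tensors (parameters and buffers), not the class
torch.save(model.state_dict(), weights_file)

# Loading: a fresh instance starts with random weights,
# which load_state_dict then overwrites; nothing is retrained
model2 = Net()
model2.load_state_dict(torch.load(weights_file))
model2.eval()  # dropout/batchnorm switch to inference behaviour

model.eval()
x = torch.randn(3, 4)
with torch.no_grad():
    print(torch.equal(model(x), model2(x)))  # True
```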
- This one is not recommended, but I don't understand why. Here I still need the class definition when loading. So why wouldn't I just save the model's state_dict using torch.save(model.state_dict(), …)? Also, I read that in this situation I should not put the model into eval mode. But that only applies if I want to train it directly afterwards, no? If I want to do more evaluation or test it on new data, I can put the model in eval mode, right? And then switch back to train mode for training, I guess.
state = {'model': model, 'state_dict': model.state_dict()}
torch.save(state, os.path.join(path_, 'model_' + name_ + '.tar'))
# and to load it again (weights_only=False is needed on PyTorch >= 2.6 to unpickle a full model):
state = torch.load(os.path.join(path_, 'model_' + name_ + '.tar'), weights_only=False)
model2 = state['model']
model2.load_state_dict(state['state_dict'])
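A sketch of the second approach in isolation, with `nn.Sequential` standing in for `Net` (an assumption, to keep the example self-contained). Pickling the whole module already stores its parameters, so saving `state_dict` next to it is redundant. Pickle records only where the class lives (module path and name), not its code, which is why the class must still be importable when loading and why the serialization notes warn that this approach can break after refactoring or in other projects. The sketch also shows that `eval()` and `train()` only flip the `training` flag, so switching to eval mode for testing and back to train mode later is fine.

```python
import os
import tempfile

import torch
import torch.nn as nn

# nn.Sequential stands in for a custom Net. For a custom class, pickle stores
# only its module path and name, so the class must be importable when loading.
model = nn.Sequential(nn.Dropout(p=0.5), nn.Linear(4, 2))
path = os.path.join(tempfile.mkdtemp(), "model_demo.pt")  # hypothetical path

# Pickle the whole module object; its parameters are included,
# so a separate state_dict is not needed
torch.save(model, path)

# weights_only=False is needed to unpickle an nn.Module (PyTorch >= 2.6
# defaults to True); only use it for files you trust
model2 = torch.load(path, weights_only=False)

# eval() only sets training=False on every submodule: dropout becomes a no-op
model2.eval()
x = torch.randn(3, 4)
with torch.no_grad():
    out_a, out_b = model2(x), model2(x)
print(torch.equal(out_a, out_b))  # True

# train() flips the flag back, e.g. before resuming training
model2.train()
print(all(m.training for m in model2.modules()))  # True
```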
I simply need to save the model so that I can use it without having to provide the whole class definition again. It would also be very useful to understand how to save it in a way that lets me train it again.
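For the two goals in this last paragraph, a sketch under stated assumptions (a hypothetical `build_model()` using `nn.Sequential` in place of `Net`, SGD as the optimizer, a hypothetical `epoch` field, temporary paths). Part 1 is a training checkpoint that stores both the model and optimizer `state_dict`s so training can resume; it still needs the model code. Part 2 is a TorchScript export, which `torch.jit.load` can load without the Python class. Recent PyTorch documentation marks TorchScript as deprecated in favour of `torch.export`, so check what your version recommends.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model():
    # Hypothetical stand-in for Net(); resuming training still needs this code
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(8, 2))


torch.manual_seed(0)
model = build_model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One dummy training step so the optimizer has momentum buffers worth saving
x, y = torch.randn(16, 4), torch.randint(0, 2, (16,))
loss = nn.functional.cross_entropy(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

out_dir = tempfile.mkdtemp()  # hypothetical output directory

# 1) Training checkpoint: model + optimizer state so training can resume
ckpt_path = os.path.join(out_dir, "checkpoint.tar")
torch.save(
    {
        "epoch": 0,  # hypothetical bookkeeping field
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    ckpt_path,
)

ckpt = torch.load(ckpt_path)
model2 = build_model()
optimizer2 = torch.optim.SGD(model2.parameters(), lr=0.1, momentum=0.9)
model2.load_state_dict(ckpt["model_state_dict"])
optimizer2.load_state_dict(ckpt["optimizer_state_dict"])
start_epoch = ckpt["epoch"] + 1
model2.train()  # continue the training loop from start_epoch

# 2) TorchScript: the saved file contains the compiled code, so loading
#    it does not require the Python class definition
scripted = torch.jit.script(model)
ts_path = os.path.join(out_dir, "model_scripted.pt")
scripted.save(ts_path)

loaded = torch.jit.load(ts_path)
loaded.eval()
model.eval()
with torch.no_grad():
    same = torch.allclose(loaded(x), model(x))
print(same)  # True
```

Saving the optimizer state matters for optimizers with internal state (momentum here, or Adam's moment estimates); without it, resumed training starts those buffers from zero.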