Saving a PyTorch model without having to redefine the class when loading it

I thought that using

import os
import torch

# Save both the pickled model object and its state_dict in one file.
state = {'model': model, 'state_dict': model.state_dict()}
torch.save(state, os.path.join(path_, 'model.tar'))
# Then load it like this:
state = torch.load(os.path.join(path_, 'model.tar'))
model2 = state['model']
model2.load_state_dict(state['state_dict'])

wouldn't require the class definition at load time, and that was the reason I was using it instead of saving only the state_dict.
But apparently that's not the case, as I get

AttributeError: Can't get attribute 'Net' on <module '__main__'>

which goes away once I add the class definition.
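The error itself comes from Python's pickle, which torch.save uses for arbitrary objects such as a whole nn.Module: a pickled object records its class only by reference (module name plus class name), not the class's source code. A minimal stdlib-only sketch, assuming a hypothetical plain Python class `Net` (not a real model), reproduces the same kind of AttributeError by deleting the class from its module to simulate a fresh script that never defined it:

```python
import pickle
import sys


class Net:
    """Hypothetical stand-in for a user-defined model class (plain Python, no torch)."""

    def __init__(self):
        self.weight = 1.0


# Pickle stores a *reference* to the class (module name + class name),
# not the class's code, so the class name appears in the serialized bytes.
payload = pickle.dumps(Net())
print(b"Net" in payload)

# Simulate loading in a script where Net was never defined:
# remove the class from the module it was defined in.
module = sys.modules[Net.__module__]
saved_class = module.Net
del module.Net

error = None
try:
    pickle.loads(payload)
except AttributeError as exc:
    # Same failure mode as torch.load on a pickled nn.Module without its class.
    error = str(exc)
print(error)

# Making the class importable again lets unpickling succeed.
module.Net = saved_class
restored = pickle.loads(payload)
print(restored.weight)
```

So storing the model object inside a dict does not change anything: the dict is pickled too, and the model inside it still points back to `Net` by name.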

Hence, I am wondering: if the class definition must be declared before loading even when the full model was saved via torch.save(model, …), why wouldn't I just save the state_dict with torch.save(model.state_dict(), …)? How do these two approaches differ?
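Roughly: torch.save(model, …) pickles the module object, so its class must be importable under the same module path at load time; torch.save(model.state_dict(), …) stores only a dict of parameter and buffer tensors, so you rebuild the model from your class and copy the weights in (the approach PyTorch's serialization tutorial recommends). Both need the class. If the goal really is to load without the Python class, one option is TorchScript, which saves the compiled code together with the weights. A minimal sketch, assuming a hypothetical small `Net` and that the model is scriptable (not every model is):

```python
import os
import tempfile

import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical small model used only for illustration."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


model = Net()
x = torch.randn(3, 4)
tmpdir = tempfile.mkdtemp()

# Option 1: save only the parameters/buffers (an ordered dict of tensors).
weights_path = os.path.join(tmpdir, "weights.pt")
torch.save(model.state_dict(), weights_path)

# Loading still needs the class: build a fresh model, then copy the tensors in.
model2 = Net()
model2.load_state_dict(torch.load(weights_path))
model2.eval()

# Option 2: TorchScript serializes the compiled module (code + weights),
# so torch.jit.load does not need the Python class definition.
scripted_path = os.path.join(tmpdir, "scripted.pt")
torch.jit.script(model).save(scripted_path)
model3 = torch.jit.load(scripted_path)

with torch.no_grad():
    out1 = model(x)
    out2 = model2(x)
    out3 = model3(x)
```

The state_dict route is less fragile across refactors (renaming or moving the class file does not break old checkpoints), while the TorchScript file can be loaded in a script, or even from C++, that has no access to your model code.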
Also, what is the difference between saving with a '.pth.tar' extension and simply '.pt'?
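As far as I know, the extension makes no difference to torch.save and torch.load: the on-disk format is chosen by torch.save itself (a zip-based archive by default since PyTorch 1.6), and '.pt', '.pth' and '.pth.tar' are naming conventions (the PyTorch tutorial mentions '.pt'/'.pth' for models and '.tar' for checkpoints). A minimal sketch, assuming PyTorch 1.6 or later with the default serialization settings, saves the same dict under three names and checks that each file is a zip archive that loads back identically:

```python
import os
import tempfile
import zipfile

import torch

tmpdir = tempfile.mkdtemp()
state = {"weight": torch.arange(6.0).reshape(2, 3)}

# torch.save ignores the extension; it is only part of the file name.
names = ("model.pt", "model.pth", "model.pth.tar")
paths = [os.path.join(tmpdir, name) for name in names]
for path in paths:
    torch.save(state, path)

# With the default format every file is a zip archive, whatever it is called.
is_zip = {os.path.basename(p): zipfile.is_zipfile(p) for p in paths}

# All three load back to the same tensors.
loaded = {os.path.basename(p): torch.load(p) for p in paths}
print(is_zip)
```

So a '.pth.tar' file is not a real tar archive; the suffix just signals "this is a checkpoint" to humans.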

Thank you very much for any help!
