My model class (an NN with one hidden layer) has `__init__(self, n_inputs, n_hidden, n_outputs)`, so when I instantiate a new model, I have to pass these arguments (the number of nodes in each layer).
After training such a model, I save it:
```python
torch.save(model.state_dict(), 'filename')
```
Then, to load it, the documentation says I first have to instantiate a new model and then load the state_dict into it:
```python
model_new = TheModelClass(n_inputs, n_hidden, n_outputs)
model_new.load_state_dict(torch.load('filename'))
```
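For reference, here is a self-contained sketch of that documented save/load pattern. `TheModelClass` is a hypothetical stand-in with assumed layer names `hidden` and `output`, and an in-memory `io.BytesIO` buffer replaces `'filename'`. It also shows why the constructor arguments matter: `load_state_dict` raises a `RuntimeError` when the saved tensor shapes do not match the freshly built model.

```python
import io

import torch
import torch.nn as nn


class TheModelClass(nn.Module):
    # Hypothetical one-hidden-layer network; the layer names are assumptions.
    def __init__(self, n_inputs, n_hidden, n_outputs):
        super().__init__()
        self.hidden = nn.Linear(n_inputs, n_hidden)
        self.output = nn.Linear(n_hidden, n_outputs)

    def forward(self, x):
        return self.output(torch.relu(self.hidden(x)))


model = TheModelClass(4, 8, 3)

# Save only the parameters; the buffer stands in for 'filename'.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# Loading works when the new model is built with the same sizes.
buffer.seek(0)
state_dict = torch.load(buffer)
model_new = TheModelClass(4, 8, 3)
model_new.load_state_dict(state_dict)

# With a different hidden size, load_state_dict reports a size mismatch.
mismatch_error = None
try:
    TheModelClass(4, 10, 3).load_state_dict(state_dict)
except RuntimeError as err:
    mismatch_error = err
```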
But what if I don’t remember the arguments of the saved model?
If I am not mistaken, one way to do it is to load
```python
model_state_dict = torch.load('filename')
```
and then read the shapes of the saved model's weight and bias tensors from `model_state_dict`, and use them as arguments to instantiate a new model into which I then load the state dict.
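A minimal sketch of that shape-reading approach, assuming the hypothetical `TheModelClass` below with two `nn.Linear` layers named `hidden` and `output` (the helper `sizes_from_state_dict` is also hypothetical). It relies on `nn.Linear` storing its weight with shape `(out_features, in_features)`.

```python
import io

import torch
import torch.nn as nn


class TheModelClass(nn.Module):
    # Hypothetical one-hidden-layer network; the layer names are assumptions.
    def __init__(self, n_inputs, n_hidden, n_outputs):
        super().__init__()
        self.hidden = nn.Linear(n_inputs, n_hidden)
        self.output = nn.Linear(n_hidden, n_outputs)

    def forward(self, x):
        return self.output(torch.relu(self.hidden(x)))


def sizes_from_state_dict(state_dict):
    # nn.Linear weights have shape (out_features, in_features).
    n_hidden, n_inputs = state_dict["hidden.weight"].shape
    n_outputs = state_dict["output.weight"].shape[0]
    return n_inputs, n_hidden, n_outputs


# Save a model whose sizes we then "forget".
original = TheModelClass(5, 16, 2)
buffer = io.BytesIO()
torch.save(original.state_dict(), buffer)
buffer.seek(0)

# Recover the constructor arguments from the tensor shapes alone.
model_state_dict = torch.load(buffer)
n_inputs, n_hidden, n_outputs = sizes_from_state_dict(model_state_dict)
model_new = TheModelClass(n_inputs, n_hidden, n_outputs)
model_new.load_state_dict(model_state_dict)
```

This only works if you know the layer names and the architecture's structure; printing `{k: v.shape for k, v in model_state_dict.items()}` is a quick way to see which keys and shapes are available.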
But is there a better way to do it?
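One common alternative is to save the constructor arguments next to the weights in a single checkpoint dictionary, so the loader never has to infer them. The sketch below assumes a hypothetical `init_args` attribute that the model records in `__init__`, plus hypothetical helpers `save_checkpoint` and `load_checkpoint`; the dictionary keys are arbitrary choices, not a PyTorch convention.

```python
import io

import torch
import torch.nn as nn


class TheModelClass(nn.Module):
    def __init__(self, n_inputs, n_hidden, n_outputs):
        super().__init__()
        # Hypothetical attribute: remember the constructor arguments.
        self.init_args = {"n_inputs": n_inputs, "n_hidden": n_hidden, "n_outputs": n_outputs}
        self.hidden = nn.Linear(n_inputs, n_hidden)
        self.output = nn.Linear(n_hidden, n_outputs)

    def forward(self, x):
        return self.output(torch.relu(self.hidden(x)))


def save_checkpoint(model, f):
    # Store the sizes and the weights together in one plain dictionary.
    torch.save({"init_args": model.init_args, "state_dict": model.state_dict()}, f)


def load_checkpoint(f):
    # Rebuild the model from the stored sizes, then restore its weights.
    checkpoint = torch.load(f)
    model = TheModelClass(**checkpoint["init_args"])
    model.load_state_dict(checkpoint["state_dict"])
    return model


model = TheModelClass(7, 32, 4)
buffer = io.BytesIO()
save_checkpoint(model, buffer)
buffer.seek(0)
model_new = load_checkpoint(buffer)
```

Because the checkpoint contains only a dict of ints and tensors, it should also load under `torch.load`'s `weights_only=True` mode, which recent PyTorch versions use by default.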