How to load a saved model without knowing the number of nodes beforehand

My model class (a neural network with one hidden layer) has __init__(self, n_inputs, n_hidden, n_outputs), so when I instantiate a new model, I have to pass these arguments (the number of nodes in each layer).

After training such a model, I save it with:
torch.save(model.state_dict(), 'filename')

Then, to load it, according to the documentation I first have to instantiate a new model and then load the state_dict into it:

model_new = TheModelClass(n_inputs, n_hidden, n_outputs)
model_new.load_state_dict(torch.load('filename'))
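For reference, here is a minimal runnable sketch of that pattern. The layer names `hidden` and `output`, the ReLU activation and the sizes 4, 8 and 3 are assumptions for illustration, not the asker's actual class. It also shows why the arguments matter: building the model with a different hidden size makes `load_state_dict` raise a `RuntimeError` reporting a size mismatch.

```python
import torch
import torch.nn as nn


class TheModelClass(nn.Module):
    # Hypothetical layout: one hidden nn.Linear layer with a ReLU, then an output layer.
    def __init__(self, n_inputs, n_hidden, n_outputs):
        super().__init__()
        self.hidden = nn.Linear(n_inputs, n_hidden)
        self.output = nn.Linear(n_hidden, n_outputs)

    def forward(self, x):
        return self.output(torch.relu(self.hidden(x)))


# Train (omitted), then save only the parameters.
model = TheModelClass(4, 8, 3)
torch.save(model.state_dict(), "filename")

# To restore, the constructor arguments must be supplied again by hand.
model_new = TheModelClass(4, 8, 3)
model_new.load_state_dict(torch.load("filename"))
model_new.eval()

# Guessing a size wrong fails loudly instead of silently.
mismatch_detected = False
try:
    TheModelClass(4, 16, 3).load_state_dict(torch.load("filename"))
except RuntimeError:
    # hidden.weight was saved with shape (8, 4), but this model expects (16, 4).
    mismatch_detected = True
```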

But what if I don’t remember the arguments of the saved model?

If I am not mistaken, one way to do it is to load the state dict on its own:
model_state_dict = torch.load('filename')
and then read the shapes of the weight and bias tensors stored in model_state_dict, use them as arguments to instantiate a new model, and finally load the state dict into that model.
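A minimal sketch of that shape-reading approach, assuming (hypothetically) that the two layers are `nn.Linear` modules registered as `hidden` and `output`, so their keys in the state dict are `hidden.weight` and `output.weight`. It relies on `nn.Linear` storing its weight as `(out_features, in_features)`; a different architecture would need different keys.

```python
import torch
import torch.nn as nn


class TheModelClass(nn.Module):
    # Hypothetical layer names "hidden" and "output"; adjust the keys below to match yours.
    def __init__(self, n_inputs, n_hidden, n_outputs):
        super().__init__()
        self.hidden = nn.Linear(n_inputs, n_hidden)
        self.output = nn.Linear(n_hidden, n_outputs)

    def forward(self, x):
        return self.output(torch.relu(self.hidden(x)))


def load_inferring_sizes(path):
    model_state_dict = torch.load(path)
    # nn.Linear weights have shape (out_features, in_features).
    n_hidden, n_inputs = model_state_dict["hidden.weight"].shape
    n_outputs = model_state_dict["output.weight"].shape[0]
    # Rebuild the model from the recovered sizes, then copy the parameters in.
    model_new = TheModelClass(n_inputs, n_hidden, n_outputs)
    model_new.load_state_dict(model_state_dict)
    return model_new
```

This works, but it couples the loading code to the parameter names, which is one reason to prefer storing the sizes explicitly.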

But is there a better way to do it?

I think you can save most Python objects with torch.save (it pickles whatever you pass it). So, at the very simplest, you could call torch.save on the tuple (model.state_dict(), n_inputs, n_hidden, n_outputs). You should probably also store those integers as attributes of the model while you're at it. I use this all the time to save the number of elapsed epochs alongside my state dicts.
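A minimal sketch of that suggestion, assuming a hypothetical one-hidden-layer `TheModelClass` that keeps its constructor arguments as attributes (`self.n_inputs` and so on); the helper names `save_with_sizes` and `load_with_sizes` are illustrative, not part of PyTorch.

```python
import torch
import torch.nn as nn


class TheModelClass(nn.Module):
    """One-hidden-layer network that remembers its own layer sizes."""

    def __init__(self, n_inputs, n_hidden, n_outputs):
        super().__init__()
        # Keep the constructor arguments so they can be saved with the weights.
        self.n_inputs = n_inputs
        self.n_hidden = n_hidden
        self.n_outputs = n_outputs
        self.hidden = nn.Linear(n_inputs, n_hidden)
        self.output = nn.Linear(n_hidden, n_outputs)

    def forward(self, x):
        return self.output(torch.relu(self.hidden(x)))


def save_with_sizes(model, path):
    # torch.save pickles the whole tuple: the state dict plus the three integers.
    torch.save((model.state_dict(), model.n_inputs, model.n_hidden, model.n_outputs), path)


def load_with_sizes(path):
    # Unpack in the same order used when saving.
    state_dict, n_inputs, n_hidden, n_outputs = torch.load(path)
    model = TheModelClass(n_inputs, n_hidden, n_outputs)
    model.load_state_dict(state_dict)
    return model
```

A positional tuple depends on remembering the order of its items; saving a dict such as {"state_dict": ..., "n_hidden": ..., "epoch": ...} instead is a common variation that keeps every value labelled and makes it easy to add extras like the epoch count.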