Hi,
I’m looking to set up a flexible model saving/loading system for running different experiments. At this stage speed is not the primary concern.
When I use the recommended model.load_state_dict(state_dict), some of my model construction parameters are not available, because they are never captured by the corresponding torch.save(model.state_dict()).
Concrete example: I have a network where, in one case, I want to output a classification, so the final layer of the model should be a sigmoid. I use a constructor flag to indicate this (MyModel(use_sigmoid_output=True)). This flag is not automatically saved in the state dict.
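For reference, here is a minimal sketch of the kind of model I mean (the layer names match the fc1/fc2/fc3 keys used further down; the exact sizes are arbitrary):

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self, input_size, output_size, H1, use_sigmoid_output=False):
        super().__init__()
        self.fc1 = nn.Linear(input_size, H1)
        self.fc2 = nn.Linear(H1, H1)
        self.fc3 = nn.Linear(H1, output_size)
        self.use_sigmoid_output = use_sigmoid_output

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        if self.use_sigmoid_output:
            # Parameter-less, so it leaves no trace in the state dict.
            x = torch.sigmoid(x)
        return x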
The question, then, is how best to save this auxiliary information that is needed at loading time. I want to avoid having to pass these options on e.g. the command line. Where can I find the recommended best practices?
Simplified example of my current loading code:
class MyModel(nn.Module):
    # ... (__init__ and forward as above)

    @staticmethod
    def FromStateDict(model_state_dict):
        # The layer sizes can be recovered from the weight shapes.
        input_size = model_state_dict['fc1.weight'].shape[1]
        H1 = model_state_dict['fc2.weight'].shape[1]
        output_size = model_state_dict['fc3.weight'].shape[0]
        # How to best handle the fact that we have a final sigmoid layer
        # (which is parameter-less and thus does not end up in the state dict)?
        # use_sigmoid_output = ..?
        model = MyModel(input_size, output_size, H1)
        # model = MyModel(input_size, output_size, H1,
        #                 use_sigmoid_output=use_sigmoid_output)
        model.load_state_dict(model_state_dict)
        return model

model_state_dict = torch.load('model.pt')
model = MyModel.FromStateDict(model_state_dict)
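One workaround I have considered is to bundle the constructor arguments with the weights in a single checkpoint dict, roughly like this (the 'config'/'state_dict' key names are just my own convention, and I don't know whether this is considered idiomatic):

# Saving side: store the constructor arguments next to the weights.
config = {'input_size': 10, 'output_size': 1, 'H1': 64,
          'use_sigmoid_output': True}
model = MyModel(**config)
torch.save({'config': config, 'state_dict': model.state_dict()}, 'model.pt')

# Loading side: rebuild the model from the stored config.
checkpoint = torch.load('model.pt')
model = MyModel(**checkpoint['config'])
model.load_state_dict(checkpoint['state_dict'])

This keeps everything in one file, but I would still like to know whether this is the recommended pattern or whether PyTorch has something better built in for it.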
Cheers,
Kim