I’m looking to have a flexible model loading/saving system to do different experiments. At this stage speed is not the primary concern.
When I use the recommended
model.load_state_dict(state_dict), some of my model construction parameters are not saved in the corresponding state dict.
Concrete example: I have a network where in one case I want to output a classification, so the final layer in the model should be a sigmoid. I use a flag to indicate this (MyModel(use_sigmoid_output)). This parameter is not automatically saved in the state dict.
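For context, the network in question looks roughly like this (a minimal sketch; the class name DeepNet, the layer names fc1–fc3, and the sizes are illustrative):

```python
import torch
import torch.nn as nn

class DeepNet(nn.Module):
    def __init__(self, input_size, output_size, H1, use_sigmoid_output=False):
        super().__init__()
        self.fc1 = nn.Linear(input_size, H1)
        self.fc2 = nn.Linear(H1, H1)
        self.fc3 = nn.Linear(H1, output_size)
        # Plain Python attribute: it never appears in state_dict()
        self.use_sigmoid_output = use_sigmoid_output

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        if self.use_sigmoid_output:
            x = torch.sigmoid(x)
        return x
```

The point is that use_sigmoid_output changes the forward pass but, being parameter-less, leaves no trace in the saved weights.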
The question is then: how do I best save this auxiliary information that is relevant when loading? I want to avoid having to pass these options at e.g. the command line. Where can I find the recommended best practices?
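One option I have considered is bundling the construction flags with the weights in a single checkpoint file (a sketch; the 'config' key name is my own choice, not a PyTorch convention, and nn.Linear stands in for my real model):

```python
import torch
import torch.nn as nn

# Hypothetical tiny model standing in for the real network.
model = nn.Linear(4, 2)

# Save the weights together with the construction-time options.
checkpoint = {
    'state_dict': model.state_dict(),
    'config': {'use_sigmoid_output': True},
}
torch.save(checkpoint, 'model.pt')

# Loading: read the flags back, rebuild the model, then restore the weights.
loaded = torch.load('model.pt')
use_sigmoid_output = loaded['config']['use_sigmoid_output']
restored = nn.Linear(4, 2)
restored.load_state_dict(loaded['state_dict'])
```

Is something along these lines the recommended approach, or is there a more idiomatic mechanism?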
class MyModel:
    # ...
    @staticmethod
    def FromStateDict(model_state_dict):
        # nn.Linear weights have shape (out_features, in_features)
        input_size = model_state_dict['fc1.weight'].shape[1]
        H1 = model_state_dict['fc2.weight'].shape[1]
        output_size = model_state_dict['fc3.weight'].shape[0]
        # How to best handle the fact that we have a final sigmoid layer
        # (which is parameter-less and thus does not end up in the state dict)?
        # use_sigmoid_output = ..?
        model = DeepNet(input_size, output_size, H1)
        # model = DeepNet(input_size, output_size, H1,
        #                 use_sigmoid_output=use_sigmoid_output)
        model.load_state_dict(model_state_dict)
        return model

model_state_dict = torch.load('model.pt')
model = MyModel.FromStateDict(model_state_dict)