Model serialisation/deserialisation


I’m looking to have a flexible model loading/saving system to do different experiments. At this stage speed is not the primary concern.

When I use the recommended model.load_state_dict(state_dict), some of my model construction parameters are not saved in the corresponding state dict.

Concrete example. I have a network where in one case I want to output a classification, so the final layer in the model should be a sigmoid. I use a flag to indicate this (MyModel(use_sigmoid_output)). This parameter is not automatically saved in the state dict.

The question is then: how best to save this auxiliary information that is relevant when loading? I want to avoid having to pass these options at e.g. the command line. Where can I find the recommended best practices?

Simplified example

import torch

class MyModel:
    # ...
    @staticmethod
    def FromStateDict(model_state_dict):
        # Recover the architecture sizes from the weight shapes.
        input_size = model_state_dict['fc1.weight'].shape[1]
        H1 = model_state_dict['fc2.weight'].shape[1]
        output_size = model_state_dict['fc3.weight'].shape[0]

        # How best to handle the fact that we have a final sigmoid layer
        # (which is parameter-less and thus does not end up in the state dict)?
        # use_sigmoid_output = ..?

        model = MyModel(input_size, output_size, H1)
        # model = MyModel(input_size, output_size, H1,
        #                 use_sigmoid_output=use_sigmoid_output)

        model.load_state_dict(model_state_dict)
        return model

model_state_dict = torch.load('')
model = MyModel.FromStateDict(model_state_dict)


It’s kind of an ugly hack, but you could register the flag as a buffer so that it will be stored in the state_dict.
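A minimal sketch of that approach (the model and flag names are illustrative, not from your code): since buffers must be tensors, the boolean flag is stored as a bool tensor, and it then appears in the state_dict alongside the weights.

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self, use_sigmoid_output=False):
        super().__init__()
        self.fc1 = nn.Linear(4, 2)
        # Register the flag as a non-trainable buffer so it is
        # included in state_dict(); buffers must be tensors.
        self.register_buffer('use_sigmoid_output',
                             torch.tensor(use_sigmoid_output))

model = MyModel(use_sigmoid_output=True)
state_dict = model.state_dict()
# 'use_sigmoid_output' is now a key in the state dict and can be
# read back when reconstructing the model:
flag = bool(state_dict['use_sigmoid_output'])
```

The downside is that the flag then looks like model state, and load_state_dict will complain about a missing/unexpected key if you later remove it.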


Thanks. In this case, maybe it’s better to wrap the weights in a dictionary like:

meta_state_dict = {
    'flag1': value,
    # ... etc ...
    'weights': model.state_dict(),
}
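As a rough sketch of the full round trip (the flag name and checkpoint path are placeholders), torch.save/torch.load handle such a plain dict directly, so the checkpoint alone carries everything needed to rebuild the model:

```python
import torch
import torch.nn as nn

# Stand-in for the real model; in practice this would be
# MyModel(..., use_sigmoid_output=...).
model = nn.Linear(4, 2)

meta_state_dict = {
    'use_sigmoid_output': True,
    'weights': model.state_dict(),
}
torch.save(meta_state_dict, 'checkpoint.pt')

# Later, in another script, reconstruct from the checkpoint alone:
checkpoint = torch.load('checkpoint.pt')
restored = nn.Linear(4, 2)  # MyModel(use_sigmoid_output=checkpoint['use_sigmoid_output'])
restored.load_state_dict(checkpoint['weights'])
```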


That would probably be a better idea!
It’s similar to saving the last epoch number in case you would like to continue the training in another script.
