Model serialisation/deserialisation


(Kim Albertsson) #1

Hi,

I’m looking to have a flexible model loading/saving system to do different experiments. At this stage speed is not the primary concern.

When I use the recommended model.load_state_dict(state_dict), some of my model construction parameters are missing, because they are not saved by the corresponding torch.save(model.state_dict(), path).

Concrete example. I have a network where in one case I want to output a classification, so the final layer in the model should be a sigmoid. I use a flag to indicate this (MyModel(use_sigmoid_output)). This parameter is not automatically saved in the state dict.

The question is then: how best to save this auxiliary information that is relevant when loading? I want to avoid having to give these options at e.g. the command line. Where can I find the recommended best practices?

Simplified example

import torch

class MyModel(torch.nn.Module):
    # ...
    @staticmethod
    def FromStateDict(model_state_dict):
        # Recover the layer sizes from the weight shapes (assuming fc1..fc3
        # are nn.Linear layers, whose weights are [out_features, in_features]).
        input_size = model_state_dict['fc1.weight'].shape[1]
        H1 = model_state_dict['fc2.weight'].shape[1]
        output_size = model_state_dict['fc3.weight'].shape[0]

        # How to best handle the fact that we have a final sigmoid layer
        # (which is parameterless and thus does not end up in the state dict)?
        # use_sigmoid_output = ..?

        model = MyModel(input_size, output_size, H1)
        # model = MyModel(input_size, output_size, H1,
        #                 use_sigmoid_output=use_sigmoid_output)
        model.load_state_dict(model_state_dict)

        return model

model_state_dict = torch.load('model.pt')
model = MyModel.FromStateDict(model_state_dict)

Cheers,
Kim


#2

It’s kind of an ugly hack, but you could register the flag as a buffer so that it will be stored in the state_dict.
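
For illustration, a minimal sketch of the buffer approach might look like the following. The two-layer architecture, the layer names, and the sizes are made up for the example; register_buffer and state_dict are the only PyTorch pieces this relies on.

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    # Hypothetical two-layer network, used only to illustrate the idea.
    def __init__(self, input_size, output_size, hidden_size, use_sigmoid_output=False):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        # Store the flag as a 0-dim bool tensor buffer so that it is
        # included in state_dict() alongside the weights.
        self.register_buffer("use_sigmoid_output", torch.tensor(use_sigmoid_output))

    def forward(self, x):
        x = self.fc2(torch.relu(self.fc1(x)))
        return torch.sigmoid(x) if bool(self.use_sigmoid_output) else x


model = MyModel(4, 1, 8, use_sigmoid_output=True)
state = model.state_dict()

# The flag can be read back from the state dict before the model is rebuilt.
flag = bool(state["use_sigmoid_output"])
restored = MyModel(
    input_size=state["fc1.weight"].shape[1],
    output_size=state["fc2.weight"].shape[0],
    hidden_size=state["fc1.weight"].shape[0],
    use_sigmoid_output=flag,
)
restored.load_state_dict(state)
```

The downside is that the flag has to be wrapped in a tensor and converted back with bool(), which is part of why this feels like a hack.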


(Kim Albertsson) #3

Hi,

Thanks. In this case, maybe it’s better to wrap the weights in a dictionary, like:

meta_state_dict = {
    'flag1': value,
    # ... etc ...
    'weights': model.state_dict()
}

Cheers,
Kim


#4

That would probably be a better idea!
It’s similar to storing the last epoch number in case you would like to continue training in another script.
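
A rough sketch of what such a wrapped checkpoint could look like is below. The key names ('config', 'epoch', 'weights'), the helper functions, and the small MyModel are assumptions made for the example, not an established convention. An in-memory buffer stands in for a file so that the snippet is self-contained.

```python
import io

import torch
import torch.nn as nn


class MyModel(nn.Module):
    # Hypothetical two-layer network, used only to illustrate the idea.
    def __init__(self, input_size, output_size, hidden_size, use_sigmoid_output=False):
        super().__init__()
        self.use_sigmoid_output = use_sigmoid_output
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.fc2(torch.relu(self.fc1(x)))
        return torch.sigmoid(x) if self.use_sigmoid_output else x


def save_checkpoint(model, config, epoch, f):
    # Bundle the constructor arguments and training progress with the weights.
    torch.save({"config": config, "epoch": epoch, "weights": model.state_dict()}, f)


def load_checkpoint(f):
    checkpoint = torch.load(f)
    # Rebuild the model from the saved constructor arguments, then load weights.
    model = MyModel(**checkpoint["config"])
    model.load_state_dict(checkpoint["weights"])
    return model, checkpoint["epoch"]


config = {"input_size": 4, "output_size": 1, "hidden_size": 8, "use_sigmoid_output": True}
model = MyModel(**config)

buffer = io.BytesIO()  # a path such as 'model.pt' works the same way
save_checkpoint(model, config, epoch=5, f=buffer)

buffer.seek(0)
restored, last_epoch = load_checkpoint(buffer)
```

Keeping the constructor arguments in one plain dictionary means the loading script needs no command-line options, and further entries (for example an optimizer state dict) can be added under their own keys.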