Best practices for saving/loading dynamic PyTorch models

Hi all,

I am currently working on a project in which I need to simulate the behaviour of a device using neural networks. I would like some flexibility in the architecture used to simulate the device, so I wrote a method that builds the network from a configuration dictionary inside the model:

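```python
import torch.nn as nn


class DeviceModel(nn.Module):
    def __init__(self, custom_dictionary):
        super(DeviceModel, self).__init__()
        self.build_my_structure(custom_dictionary)

    def build_my_structure(self, custom_dictionary):
        hidden_sizes = custom_dictionary["hidden_sizes"]
        input_layer = nn.Linear(custom_dictionary["D_in"], hidden_sizes[0])
        activation = self._get_activation(custom_dictionary["activation"])
        output_layer = nn.Linear(hidden_sizes[-1], custom_dictionary["D_out"])
        modules = [input_layer, activation]

        # One Linear layer per consecutive pair of hidden sizes, each followed
        # by the (stateless) activation module, which can be reused safely.
        hidden_layers = zip(hidden_sizes[:-1], hidden_sizes[1:])
        for h_1, h_2 in hidden_layers:
            hidden_layer = nn.Linear(h_1, h_2)
            modules.append(hidden_layer)
            modules.append(activation)

        modules.append(output_layer)
        self.model = nn.Sequential(*modules)

    def _get_activation(self, name):
        # Not shown in the original post; this mapping from a config string
        # to an activation module is an assumed implementation.
        activations = {"relu": nn.ReLU, "tanh": nn.Tanh, "sigmoid": nn.Sigmoid}
        return activations[name]()

    def forward(self, x):
        return self.model(x)
```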

What would be the best practice for saving and loading this model? Should I save the whole model directly with torch.save, or should I save its state_dict? If the latter is better, I would have to rebuild the model from scratch using the original config dictionary and then load the state_dict; how can I store both of them together?

Or perhaps I am just using the library wrong, and there is another way to create a model dynamically? Or is there simply no support for this?

Thank you very much in advance. Any help would be much appreciated.

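I would recommend saving and loading the state_dict and recreating the model from a stored config. Saving the whole model with torch.save(model) pickles it, and the pickle refers to the model class by its module path instead of storing the class code, so loading it might break in various ways, e.g. if you change the folder structure or rename the file that defines the class.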
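To illustrate why the config has to be kept, here is a minimal sketch assuming a hypothetical `build_mlp` helper as a compact stand-in for `DeviceModel`. A state_dict holds only named tensors and no architecture information, so the weights load into a model rebuilt with the same sizes but are rejected by one built with different hidden sizes.

```python
import torch
import torch.nn as nn


def build_mlp(d_in, hidden, d_out):
    # Hypothetical stand-in for DeviceModel: the architecture depends on the config.
    sizes = [d_in, *hidden]
    layers = []
    for a, b in zip(sizes[:-1], sizes[1:]):
        layers += [nn.Linear(a, b), nn.ReLU()]
    layers.append(nn.Linear(sizes[-1], d_out))
    return nn.Sequential(*layers)


original = build_mlp(2, [16, 16], 1)
state = original.state_dict()

# Only parameter names and tensors are stored (ReLU has no parameters).
print(list(state.keys()))
# → ['0.weight', '0.bias', '2.weight', '2.bias', '4.weight', '4.bias']

# Rebuilding with the same config lets the weights load cleanly ...
same = build_mlp(2, [16, 16], 1)
same.load_state_dict(state)

# ... while a model built from a different config cannot accept them.
different = build_mlp(2, [32], 1)
load_failed = False
try:
    different.load_state_dict(state)
except RuntimeError:
    # Shape mismatches and unexpected keys are reported as a RuntimeError.
    load_failed = True
print(load_failed)  # → True
```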

Yes, you would have to recreate the model. You can store the state_dict together with other objects (e.g. the config dictionary) in a dict and call torch.save() on that dict, as shown in the ImageNet example.
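A minimal sketch of that pattern, assuming hypothetical `build_model`, `save_checkpoint` and `load_checkpoint` helpers (they are not part of PyTorch or the ImageNet example). `build_model` mirrors `DeviceModel.build_my_structure` with ReLU as the activation. The config is a plain dict of numbers and lists, so it can be saved next to the state_dict in the same file.

```python
import io

import torch
import torch.nn as nn


def build_model(config):
    # Hypothetical helper mirroring DeviceModel.build_my_structure:
    # Linear layers sized by the config, with ReLU in between.
    sizes = [config["D_in"], *config["hidden_sizes"]]
    layers = []
    for h_1, h_2 in zip(sizes[:-1], sizes[1:]):
        layers += [nn.Linear(h_1, h_2), nn.ReLU()]
    layers.append(nn.Linear(sizes[-1], config["D_out"]))
    return nn.Sequential(*layers)


def save_checkpoint(model, config, f):
    # Store the plain config dict next to the weights in a single file.
    torch.save({"config": config, "state_dict": model.state_dict()}, f)


def load_checkpoint(f):
    checkpoint = torch.load(f)
    model = build_model(checkpoint["config"])  # recreate the architecture first
    model.load_state_dict(checkpoint["state_dict"])  # then restore the weights
    return model, checkpoint["config"]


config = {"D_in": 3, "hidden_sizes": [8, 8], "D_out": 2}
model = build_model(config)

# An in-memory buffer stands in for a file path such as "device_model.pt".
buffer = io.BytesIO()
save_checkpoint(model, config, buffer)
buffer.seek(0)
restored, restored_config = load_checkpoint(buffer)

x = torch.randn(4, 3)
print(torch.allclose(model(x), restored(x)))  # → True
```

Because the file contains only the config and tensors, no reference to the model class is pickled, so moving or renaming the file that defines the class does not affect loading; only the code that rebuilds the model from the config has to stay compatible.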