Storing and loading a model whose modules have extra attributes

Hello everyone.

I am trying to create an auto-encoder architecture for image segmentation using a framework for rotation invariance. The modules of this framework inherit from the torch.nn classes but have some extra attributes as well.
For the Conv2d module, for example, there are extra attributes besides weight and bias. I train my network and save it normally with torch.save(model.state_dict(), PATH). When I then load it with model.load_state_dict(torch.load(PATH)) and try to predict an image, I get errors referring to the keys of these extra attributes. More specifically, I get:
RuntimeError: Error(s) in loading state_dict:
Missing key(s) in state_dict: …
For every convolutional block, the extra attributes have no value in the dictionary.
I suppose that when the dictionary is created, only the standard attributes of the model's components (pooling, conv, etc.) are added to it.
Can I somehow change that for custom classes that inherit from torch.nn modules?

Thank you!

I would say it should always work as long as you add them as nn.Parameters or buffers, since those are the ones tracked by the state_dict.
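A minimal sketch of which attributes end up in the state_dict. `RotConv2d`, `extra_weight`, `basis` and `plain` are hypothetical stand-ins, not the rotation-invariance framework's real names:

```python
import torch
import torch.nn as nn


class RotConv2d(nn.Conv2d):
    """Hypothetical Conv2d subclass standing in for a rotation-invariance layer."""

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__(in_channels, out_channels, kernel_size)
        # Learnable extra attribute: tracked automatically because it is an nn.Parameter.
        self.extra_weight = nn.Parameter(torch.zeros(out_channels))
        # Fixed, non-learnable tensor: registered as a buffer so it is saved too.
        self.register_buffer("basis", torch.eye(kernel_size))
        # Plain tensor attribute: NOT tracked, so it never reaches the state_dict.
        self.plain = torch.ones(kernel_size)


layer = RotConv2d(1, 4, 3)
print(list(layer.state_dict().keys()))
# ['weight', 'bias', 'extra_weight', 'basis']
```

Parameters come first, then buffers. `plain` is missing, so its value is not saved and cannot be restored from the checkpoint.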
If you try to save other types of variables, it will fail, because they are simply not included in the state_dict. I would recommend passing them as arguments to __init__, so that the constructor can recreate them when the class is instantiated.
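A sketch of the "pass them in `__init__`" approach, assuming the extra state can be rebuilt from a hypothetical `num_rotations` argument. The `config` dict is stored next to the weights so the same constructor call can be repeated before loading. An in-memory buffer stands in for `PATH`:

```python
import io

import torch
import torch.nn as nn


class RotConv2d(nn.Conv2d):
    """Hypothetical layer whose extra state depends on a constructor argument."""

    def __init__(self, in_channels, out_channels, kernel_size, num_rotations=4):
        super().__init__(in_channels, out_channels, kernel_size)
        # A plain Python int is not a tensor, so it cannot live in the state_dict;
        # it is kept as a constructor argument instead.
        self.num_rotations = num_rotations
        # Tensors derived from it are registered so their values are saved as well.
        self.register_buffer("angles", torch.arange(num_rotations) * (360.0 / num_rotations))


config = {"in_channels": 1, "out_channels": 4, "kernel_size": 3, "num_rotations": 8}
model = RotConv2d(**config)

# Save the constructor arguments alongside the weights.
buf = io.BytesIO()
torch.save({"config": config, "state_dict": model.state_dict()}, buf)
buf.seek(0)

# Rebuild the module with the same arguments, then load the tensors into it.
checkpoint = torch.load(buf)
restored = RotConv2d(**checkpoint["config"])
restored.load_state_dict(checkpoint["state_dict"])
print(restored.num_rotations)
```

Because the rebuilt module registers the same buffers with the same shapes, `load_state_dict` finds every key it expects.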

Another possibility is that you are kind of hardcoding them instead of using the provided tools that properly register parameters (nn.Parameter, register_buffer).
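A small comparison of a hardcoded tensor attribute and a registered buffer across a save/load round trip. `Hardcoded`, `Registered`, `scale` and `round_trip` are hypothetical names used only for illustration:

```python
import io

import torch
import torch.nn as nn


class Hardcoded(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain attribute: ignored by state_dict() and by .to().
        self.scale = torch.ones(1)


class Registered(nn.Module):
    def __init__(self):
        super().__init__()
        # Registered buffer: saved, loaded and moved together with the module.
        self.register_buffer("scale", torch.ones(1))


def round_trip(cls):
    """Change `scale`, save the state_dict, and load it into a fresh instance."""
    trained = cls()
    trained.scale.fill_(2.0)
    buf = io.BytesIO()
    torch.save(trained.state_dict(), buf)
    buf.seek(0)
    fresh = cls()
    fresh.load_state_dict(torch.load(buf))
    return fresh.scale.item()


print(round_trip(Hardcoded))   # 1.0: the change was lost
print(round_trip(Registered))  # 2.0: restored from the checkpoint
```

In both cases loading raises no error. The hardcoded value silently falls back to whatever the constructor sets.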


Thank you very much! I actually hadn't thought of passing them to __init__; I've been struggling to integrate two models. I will try your suggestions.