How to save a model attribute with state_dict()

Hi,

I have a class DiscreteMLP(torch.nn.Module), in which I override state_dict() and load_state_dict() so that they also save and load the class attribute self.theta.


    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Overrides state_dict() to also save the theta value."""
        original_dict = super().state_dict(destination, prefix, keep_vars)
        original_dict[prefix + 'theta'] = self.theta
        return original_dict

    def load_state_dict(self, state_dict, strict=True):
        """Overrides load_state_dict() to also load the theta value."""
        theta = state_dict.pop('theta')
        self.set_theta(theta)
        super().load_state_dict(state_dict, strict)

It works fine if I save and load this MLP on its own. However, when I integrate the MLP into a bigger model, I get this error:

        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
    >                              self.__class__.__name__, "\n\t".join(error_msgs)))
    E   RuntimeError: Error(s) in loading state_dict for OptionsGraph:
    E       Unexpected key(s) in state_dict: "0.theta".

The bigger model consists of multiple layers, where each layer is one of these MLPs.

I could not find a proper PyTorch tutorial on how to override state_dict() and load_state_dict() so that they handle this case. Could you please help?

Thanks a lot!


If I were you, I would register theta with register_buffer in the custom MLP, like running_mean in batch norm (here is the link to that part).
In that case it's not necessary to override load_state_dict and state_dict: buffers are saved and loaded automatically, with the correct key prefix even when the module is nested inside a bigger model.
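
For reference, here is a minimal sketch of that approach. The layer sizes, the single linear layer, and the forward pass are made up for illustration; the register_buffer call is the only relevant part:

    import torch

    class DiscreteMLP(torch.nn.Module):
        def __init__(self, in_features, out_features, theta=1.0):
            super().__init__()
            self.linear = torch.nn.Linear(in_features, out_features)
            # A buffer is included in state_dict() and restored by
            # load_state_dict() automatically, with the correct prefix
            # even when this module is nested inside a bigger model.
            self.register_buffer('theta', torch.tensor([float(theta)]))

        def forward(self, x):
            return self.linear(x) * self.theta

    # Nesting now works without any overrides:
    model = torch.nn.Sequential(DiscreteMLP(4, 4), DiscreteMLP(4, 2))
    state = model.state_dict()    # contains '0.theta' and '1.theta'
    model.load_state_dict(state)  # no "Unexpected key(s)" error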


Thank you, that's exactly what I was looking for: a clean way to register theta.

    self.register_buffer('theta', torch.tensor([theta]))
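
Note that a buffer registered this way is still accessible as self.theta and is moved together with the module by .to() / .cuda(), but it is not returned by parameters(), so the optimizer will not update it.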