Hi,
I have a class `DiscreteMLP(torch.nn.Module)` in which I override `state_dict()` and `load_state_dict()` so that the class attribute `self.theta` is saved and loaded along with the module's parameters.
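For context, the class looks roughly like this (simplified sketch; the actual layers, constructor arguments, and the type of `theta` don't matter for the question):

```python
import torch

class DiscreteMLP(torch.nn.Module):
    def __init__(self, in_dim, out_dim, theta=1.0):
        super().__init__()
        self.linear = torch.nn.Linear(in_dim, out_dim)  # placeholder layer
        self.theta = theta  # plain Python attribute, not a Parameter or buffer

    def set_theta(self, theta):
        self.theta = theta
```

Inside `DiscreteMLP` I then override the two methods: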
```python
def state_dict(self, destination=None, prefix='', keep_vars=False):
    """Overrides state_dict() to also save the theta value."""
    original_dict = super().state_dict(
        destination=destination, prefix=prefix, keep_vars=keep_vars)
    # Store theta under the prefix a parent module passes in.
    original_dict[prefix + 'theta'] = self.theta
    return original_dict

def load_state_dict(self, state_dict, strict=True):
    """Overrides load_state_dict() to also load the theta value."""
    # Pop theta before delegating, so the base implementation
    # doesn't complain about an unexpected key.
    theta = state_dict.pop('theta')
    self.set_theta(theta)
    return super().load_state_dict(state_dict, strict=strict)
```
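Saving and loading the MLP on its own works fine, e.g. (sketch; the file name and constructor arguments are placeholders):

```python
mlp = DiscreteMLP(4, 2, theta=0.5)
torch.save(mlp.state_dict(), 'mlp.pt')  # checkpoint includes the extra 'theta' entry

restored = DiscreteMLP(4, 2)
restored.load_state_dict(torch.load('mlp.pt'))  # pops 'theta' and calls set_theta()
assert restored.theta == 0.5
```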
However, when I integrate this MLP into a bigger model, loading fails:
```
    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
>           self.__class__.__name__, "\n\t".join(error_msgs)))
E   RuntimeError: Error(s) in loading state_dict for OptionsGraph:
E       Unexpected key(s) in state_dict: "0.theta".
```
The bigger model (`OptionsGraph`) has multiple layers, where each layer is one of these MLPs.
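The `0.` in the unexpected key suggests the problem can be reproduced with any container that registers the MLPs as numbered children, e.g. (minimal sketch, not my actual model):

```python
big = torch.nn.Sequential(DiscreteMLP(4, 4), DiscreteMLP(4, 2))
torch.save(big.state_dict(), 'big.pt')
# The state_dict() recursion does call my override, so the checkpoint
# contains '0.theta' and '1.theta' alongside the usual weights.

big2 = torch.nn.Sequential(DiscreteMLP(4, 4), DiscreteMLP(4, 2))
big2.load_state_dict(torch.load('big.pt'))
# RuntimeError: Unexpected key(s) in state_dict: "0.theta", "1.theta".
# Loading through the container never calls DiscreteMLP.load_state_dict(),
# so my override never gets a chance to pop the extra keys.
```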
I could not find a proper PyTorch tutorial on how to override `state_dict()` and `load_state_dict()` so that they handle this nested case. Could you please help?
Thanks a lot!