Customize state_dict() and load_state_dict() in PyTorch

I have a nested set of classes (each subclassing torch.nn.Module). I need to do some preprocessing before saving the weights of one of the nested classes. Is it possible to override the state_dict() function so that I can insert the preprocessing in my custom implementation?

Sample code:

class A(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.b1 = B1()
        self.b2 = B2()

class B1(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.var = torch.nn.Parameter(torch.randn(3, 5))

class B2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.var = torch.nn.Parameter(torch.randn(3, 5))

    def state_dict(self, *args, **kwargs):
        # override the default state_dict to store `var` as bool
        state = super().state_dict(*args, **kwargs)
        state['var'] = self.var.bool().cpu()
        return state

    def load_state_dict(self, state_dict, strict=True):
        # cast `var` back to float before the default loading logic runs
        state_dict['var'] = state_dict['var'].float()
        return super().load_state_dict(state_dict, strict=strict)

Specifically, for one of the classes, I want to convert the variable to bool before saving it and convert it back to float while loading. I can't make the tensor bool by default, since it needs to be learned as a float value.
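To make the intended round trip concrete, here is a minimal, self-contained sketch of the override approach (the module name `B2` and the use of `torch.randn` for initialization are just for illustration). Note the caveat that when such a module is nested inside a parent, the parent's state_dict() passes a key prefix (e.g. 'b2.var'), which this naive override does not handle:

```python
import io
import torch

class B2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.var = torch.nn.Parameter(torch.randn(3, 5))

    def state_dict(self, *args, **kwargs):
        # store `var` as a bool tensor instead of float
        state = super().state_dict(*args, **kwargs)
        state['var'] = state['var'].bool().cpu()
        return state

    def load_state_dict(self, state_dict, strict=True):
        # copy to avoid mutating the caller's dict, then cast back to float
        state_dict = dict(state_dict)
        state_dict['var'] = state_dict['var'].float()
        return super().load_state_dict(state_dict, strict=strict)

# round trip through torch.save / torch.load
model = B2()
buf = io.BytesIO()
torch.save(model.state_dict(), buf)
buf.seek(0)
restored = B2()
restored.load_state_dict(torch.load(buf))
print(restored.var.dtype)  # torch.float32
```

The saved dict holds a bool tensor, and the restored parameter is float again, as required; the bool-to-float cast recovers only 0.0/1.0 values, which is the intended lossy behavior here.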

The code I am working with is this. There, the saving of such a variable is hard-coded. I don't want to do that; I want state_dict() and load_state_dict() to take care of the conversion automatically.

PS: While this discussion indicates that customizing is not possible, I am hoping PyTorch has added it in recent versions, or that there are other ways of achieving this.

Based on this feature request, the hooks already mentioned in the linked post are still internal and can break at any point. You could of course use them, but you might need to adapt your code later, since these internal APIs come with no stability guarantee.
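For reference, a sketch of what using those internal hooks could look like. I'm assuming the private `nn.Module._register_state_dict_hook` and `_register_load_state_dict_pre_hook` APIs as they exist in recent releases; being internal, their names and signatures may change. Unlike a plain state_dict() override, the hooks receive the key prefix, so they also work when the module is nested inside a parent:

```python
import torch

class B2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.var = torch.nn.Parameter(torch.randn(3, 5))
        # internal, unstable APIs: use at your own risk
        self._register_state_dict_hook(self._cast_to_bool)
        self._register_load_state_dict_pre_hook(self._cast_to_float)

    @staticmethod
    def _cast_to_bool(module, state_dict, prefix, local_metadata):
        # runs after this module's tensors were collected into state_dict;
        # `prefix` is e.g. 'b2.' when nested in a parent module
        key = prefix + 'var'
        state_dict[key] = state_dict[key].bool().cpu()
        return state_dict

    def _cast_to_float(self, state_dict, prefix, local_metadata, strict,
                       missing_keys, unexpected_keys, error_msgs):
        # runs before load_state_dict() copies tensors into the parameters
        key = prefix + 'var'
        if key in state_dict:
            state_dict[key] = state_dict[key].float()
```

With this, `B2().state_dict()['var']` comes out as a bool tensor, and load_state_dict() casts it back to float before copying it into the parameter, with no hard-coded handling at the call site.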