Customize state_dict() and load_state_dict() in PyTorch

I have a nested set of classes (each a subclass of torch.nn.Module). I need to do some preprocessing before saving the weights of one of the nested classes. Is it possible to override the state_dict() function so that I can insert the preprocessing into my custom implementation?

Sample code:

import torch

class A(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.b1 = B1()
        self.b2 = B2()

class B1(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.var = torch.nn.Parameter(torch.rand(3, 5))

class B2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.var = torch.nn.Parameter(torch.rand(3, 5))

    def state_dict(self, *args, **kwargs):
        # Override the default state_dict: store 'var' as a bool tensor.
        # When called from a parent module, keys carry a prefix such as 'b2.'.
        state_dict1 = super().state_dict(*args, **kwargs)
        prefix = kwargs.get('prefix', '')
        state_dict1[prefix + 'var'] = self.var.detach().bool()
        return state_dict1

    def load_state_dict(self, state_dict, strict=True):
        # Convert 'var' back to float before the default loading logic runs.
        state_dict = dict(state_dict)
        state_dict['var'] = state_dict['var'].float()
        return super().load_state_dict(state_dict, strict)
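A quick check of which overrides PyTorch actually calls is useful for the sample above. The sketch below uses hypothetical `Parent`/`Child` modules (not from the question) that record each overridden method as it runs. As far as I can tell from the `torch.nn.Module` source, a parent's `state_dict()` calls each child's `state_dict()`, but a parent's `load_state_dict()` walks the children through their `_load_from_state_dict()` instead. A `load_state_dict()` override on a nested module such as `B2` is therefore skipped when loading through `A`.

```python
import torch


class Child(torch.nn.Module):
    """Hypothetical submodule that records which overridden methods run."""

    def __init__(self):
        super().__init__()
        self.var = torch.nn.Parameter(torch.rand(3, 5))
        self.calls = []

    def state_dict(self, *args, **kwargs):
        self.calls.append("state_dict")
        return super().state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        self.calls.append("load_state_dict")
        return super().load_state_dict(*args, **kwargs)


class Parent(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.child = Child()


parent = Parent()
saved = parent.state_dict()     # the parent calls Child.state_dict()
parent.load_state_dict(saved)   # the parent does NOT call Child.load_state_dict()
print(parent.child.calls)       # ['state_dict']
```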

Specifically, for one of the classes, I want to convert the variable to bool before saving it and convert it back to float when loading it. I can't make the tensor bool by default, since it needs to be learned as a float value.

The code I am working with is this. There, the saving of such a variable is hard-coded, which I want to avoid. I want state_dict() and load_state_dict() to take care of the conversion automatically.
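Here is a minimal sketch of one way to get the conversion automatically without touching the saving code. It overrides the underscore-prefixed `_save_to_state_dict()` and `_load_from_state_dict()` methods of `torch.nn.Module`, which are called for every submodule during `state_dict()` and `load_state_dict()`. Their docstrings mention overriding them for class-specific behavior, but they are not part of the public API, so this assumes their current signatures and may need adjusting across releases. Only `var` is handled, and `A`/`B2` are trimmed-down versions of the classes in the question.

```python
import torch


class B2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.var = torch.nn.Parameter(torch.rand(3, 5))

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # Default logic first: writes this module's own parameters and buffers.
        super()._save_to_state_dict(destination, prefix, keep_vars)
        # Replace the float entry with a bool copy; `prefix` is e.g. 'b2.' when nested.
        destination[prefix + "var"] = self.var.detach().bool()

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        key = prefix + "var"
        if key in state_dict:
            # Cast back to the parameter's dtype before the default copy runs.
            state_dict[key] = state_dict[key].to(self.var.dtype)
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)


class A(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.b2 = B2()


model = A()
sd = model.state_dict()
print(sd["b2.var"].dtype)        # torch.bool

restored = A()
restored.load_state_dict(sd)
print(restored.b2.var.dtype)     # torch.float32
```

Because the key is built from `prefix`, the same class works whether `B2` is saved on its own (key `'var'`) or nested inside `A` (key `'b2.var'`).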

PS: While this discussion indicates that customizing is not possible, I am hoping PyTorch has added it in recent versions, or that there is some other way of achieving this.

Based on this feature request, the hooks mentioned in the linked post are still internal and can break at any point. You could of course use them, but you might need to adapt your code later, since these internal APIs come with no stability guarantee.
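For illustration, here is a sketch of the hook-based variant, assuming the internal `_register_state_dict_hook()` and `_register_load_state_dict_pre_hook()` methods of `torch.nn.Module` with the hook signatures found in the module source of recent releases. As noted above, these are internal, so names and signatures may change. The helper functions `_var_to_bool` and `_var_to_float` are hypothetical names.

```python
import torch


def _var_to_bool(module, state_dict, prefix, local_metadata):
    # state_dict hook: runs after the module's entries were written; edit in place.
    key = prefix + "var"
    state_dict[key] = state_dict[key].bool()


def _var_to_float(state_dict, prefix, local_metadata, strict,
                  missing_keys, unexpected_keys, error_msgs):
    # load pre-hook: runs before the entries are copied into the parameters.
    key = prefix + "var"
    if key in state_dict:
        state_dict[key] = state_dict[key].float()


class B2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.var = torch.nn.Parameter(torch.rand(3, 5))
        # Internal (underscore) APIs: may change between PyTorch releases.
        self._register_state_dict_hook(_var_to_bool)
        self._register_load_state_dict_pre_hook(_var_to_float)


class A(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.b2 = B2()


model = A()
sd = model.state_dict()
print(sd["b2.var"].dtype)        # torch.bool

restored = A()
restored.load_state_dict(sd)
print(restored.b2.var.dtype)     # torch.float32
```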