.to(dtype) method with custom module and buffer of fixed precision

Hello,

I made a custom module that needs high precision (float64) for a few operations in the forward pass. Those operations involve a few float64 tensors registered as buffers in the module, as shown in the following snippet:

class custom_module(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('foo_tensor', get_foo_tensor())
        self.register_buffer('float64_tensor',  get_float64_tensor())

    def forward(self, x):
        input_dtype = x.dtype
        x = some_regular_op(x, self.foo_tensor)

        # Avoid any autocast context messing around here
        with torch.cuda.amp.autocast(enabled=False):
            x = x.to(dtype = torch.float64)
            x = some_high_precision_op(x, self.float64_tensor)

        x = x.to(input_dtype)
        return x

As I registered my tensors as buffers, I can easily change the device of my module with .to(device). Now I would also like to be able to use .to(dtype) without affecting the precision of float64_tensor. Is that possible in some way?
Furthermore, my custom_module may be used as a sub-module inside a model, so I want to be able to call model.to(dtype) and still preserve the precision of this specific float64_tensor.

Any help would be great!

If you think that’s a good idea too, maybe I can ask for a freeze_dtype option for the register_buffer method as a feature request.

So, the simplest workaround I found is to avoid registering float64_tensor as a buffer and to move it to the current device during the forward pass if needed. I also added hooks to the state_dict so that float64_tensor is correctly saved and loaded with my module.

Snippet of code with the workaround:

class custom_module(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('foo_tensor', get_foo_tensor())
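        # Plain attribute (not a buffer), so .to(dtype) leaves it in float64
        self.float64_tensor = get_float64_tensor()

        # Register state_dict hooks to get/set float64_tensor when saving/loading the module
        def set_f64_tensor_from_state_dict_hook(state_dict, prefix, *args, **kwargs):
            # prefix is non-empty when this module is a sub-module of a larger model.
            # Pop the key so that strict loading does not report it as unexpected.
            key = prefix + 'float64_tensor'
            if key in state_dict:
                self.float64_tensor = state_dict.pop(key)
        self._register_load_state_dict_pre_hook(set_f64_tensor_from_state_dict_hook)
        def update_state_dict_hook(module, state_dict, prefix, *args, **kwargs):
            state_dict[prefix + 'float64_tensor'] = module.float64_tensor
            return state_dict
        self._register_state_dict_hook(update_state_dict_hook)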

    def forward(self, x):
        input_dtype = x.dtype
        x = some_regular_op(x, self.foo_tensor)

        # Change device of f64 tensor if needed
        self.float64_tensor = self.float64_tensor.to(x.device)

        # Avoid any autocast context messing around here
        with torch.cuda.amp.autocast(enabled=False):
            x = x.to(dtype = torch.float64)
            x = some_high_precision_op(x, self.float64_tensor)

        x = x.to(input_dtype)
        return x

This is a good enough solution for me. But please don’t hesitate to comment if you see any problem with this approach or if you have a smarter idea.
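For anyone who wants to check this workaround quickly, here is a minimal sketch that nests the module inside a Sequential, casts the model to float16, and does a strict state_dict round trip. The tensors stand in for get_foo_tensor() and get_float64_tensor(), and the names CustomModule, load_hook and save_hook are only illustrative. It assumes the underscore-prefixed hook registration methods behave as in recent PyTorch releases; they are private and may change.

```python
import torch


class CustomModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Stand-in for get_foo_tensor(): a regular buffer that follows .to(dtype).
        self.register_buffer('foo_tensor', torch.ones(3))
        # Stand-in for get_float64_tensor(): a plain attribute, not a buffer,
        # so Module.to() never converts it.
        self.float64_tensor = torch.arange(3, dtype=torch.float64)

        def load_hook(state_dict, prefix, *args, **kwargs):
            key = prefix + 'float64_tensor'
            if key in state_dict:
                # Pop so that strict loading does not flag the key as unexpected.
                self.float64_tensor = state_dict.pop(key)
        self._register_load_state_dict_pre_hook(load_hook)

        def save_hook(module, state_dict, prefix, *args, **kwargs):
            state_dict[prefix + 'float64_tensor'] = module.float64_tensor
            return state_dict
        self._register_state_dict_hook(save_hook)


# Use the module as a sub-module and cast the whole model.
model = torch.nn.Sequential(torch.nn.Linear(3, 3), CustomModule())
model.to(torch.float16)
state = model.state_dict()

# Load into a fresh model whose float64 tensor has a different value.
restored = torch.nn.Sequential(torch.nn.Linear(3, 3), CustomModule())
restored[1].float64_tensor = torch.zeros(3, dtype=torch.float64)
restored.load_state_dict(state)  # strict=True by default
```

The prefix handling matters here: inside the Sequential the key is stored as '1.float64_tensor', not 'float64_tensor'.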

I think you found a pretty good (and thorough) solution here, thank you for sharing!
You might double-check that it works with distributed training, if that is something you need, and it seems to stay within the documented API. In particular, the obvious alternative – overriding Module.to – would also work but bears more risk from messing with PyTorch internals.

Best regards

Thomas

Thanks for the feedback.

Good point about distributed. I’m not using it, but I like to think about most use cases (unfortunately, I have only one GPU, so I don’t know if I can actually perform the test).

I also thought about overriding Module.to, but it would not work if my custom_module is registered as a sub-module of model when running model.to(dtype), as only the to() method of model would be called and not the one of custom_module. For such an approach I would have to override the .to method of the float64_tensor itself, if I understood the source correctly (see the line return self._apply(convert) and how the function convert is defined).
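A possible variant of the override idea, offered only as a sketch: since Module.to ends in self._apply(convert), and _apply calls each child's own _apply, overriding _apply (rather than to) in the custom module should also take effect when the module is nested. This relies on a private method whose signature has changed across PyTorch versions, hence the *args/**kwargs pass-through. The class name KeepFloat64 and the stand-in tensors are hypothetical, and using foo_tensor as the device reference is an assumption of this example.

```python
import torch


class KeepFloat64(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Stand-ins for get_foo_tensor() and get_float64_tensor().
        self.register_buffer('foo_tensor', torch.ones(3))
        self.register_buffer('float64_tensor',
                             torch.full((3,), 1 / 3, dtype=torch.float64))

    def _apply(self, fn, *args, **kwargs):
        # Keep a reference to the untouched float64 buffer; fn returns new
        # tensors when it converts, so this one is not modified.
        saved = self.float64_tensor
        super()._apply(fn, *args, **kwargs)
        # Follow device moves (taken from another buffer) but force float64.
        self.float64_tensor = saved.to(device=self.foo_tensor.device,
                                       dtype=torch.float64)
        return self


# The parent's .to() reaches the child through _apply, not through child.to().
model = torch.nn.Sequential(torch.nn.Linear(3, 3), KeepFloat64())
model.to(torch.float16)
```

Because float64_tensor stays a registered buffer here, it is saved and loaded through state_dict without extra hooks. The trade-off is the dependence on PyTorch internals mentioned earlier in the thread.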