Apply `torch.autograd.Function` on layer weights without losing gradients

Hello,
I was trying to wrap layers with a generic wrapper and then perform whatever operation the wrapped layer defines, like here:

from torch.nn import Module

class TorchModuleWrapper(Module):
    def __init__(self, torch_object: Module):
        super().__init__()
        # save the layer in _wrapped_object
        self._wrapped_object = torch_object

    def __call__(self, x):
        # feed the input forward through the wrapped layer
        return self._wrapped_object(x)

The problem appeared when I tried to apply a torch.autograd.Function to the layer weights inside the forward pass: the function I need takes learnable parameters, and those learnable parameters lost their gradients.

There is one solution, which I commented out in the example below, of applying the operation between the input and the weights directly, but then it would no longer be a generic wrapper.
So, is there any solution to apply the function on the weights without losing the gradient for the learnable parameters located outside the layer?
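As far as I can tell, the gradient is lost because Parameter() re-wraps the tensor as a fresh leaf. A minimal check of that, independent of my wrapper:

```python
import torch
from torch.nn import Parameter

s = torch.ones(1, requires_grad=True)
w = torch.randn(3, 3)
q = torch.round(w * s) / s           # non-leaf tensor, connected to s
p = Parameter(q)                     # Parameter() creates a new leaf
print(q.grad_fn is not None)         # True
print(p.grad_fn is None, p.is_leaf)  # True True: the link back to s is gone
```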

import torch
from torch.autograd import Function

class PACT(Function):
    @staticmethod
    def forward(ctx, x, s):
        return torch.round(x * s) / s

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through estimator for x
        dLdx = grad_output
        # reduce to the shape of the (1,)-element scale
        dLds = grad_output.sum().unsqueeze(0)
        return dLdx, dLds
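To confirm the Function itself propagates gradients when applied to leaf tensors directly (copied here so it runs standalone; the scale gradient is summed to the scale's shape so backward type-checks):

```python
import torch
from torch.autograd import Function

# standalone copy of the PACT Function from above
class PACT(Function):
    @staticmethod
    def forward(ctx, x, s):
        return torch.round(x * s) / s

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through pass for x; sum for the (1,)-shaped scale
        return grad_output, grad_output.sum().unsqueeze(0)

w = torch.randn(4, 4, requires_grad=True)
s = torch.ones(1, requires_grad=True)
PACT.apply(w, s).sum().backward()
print(w.grad is not None, s.grad is not None)  # True True
```

So gradients only disappear once the result is re-wrapped inside the wrapper.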

import torch
from torch.nn import Module, Parameter

class TorchModuleWrapper(Module):
    def __init__(self, torch_object: Module):
        super().__init__()
        self.bits = 8
        torch_object.bias = None
        self._wrapped_object = torch_object
        # initialise the scale to 1 (Tensor(1) would be uninitialised memory)
        self.scale = Parameter(torch.ones(1))

    def __call__(self, x):
        weight = self._wrapped_object.weight
        quantized_weight = PACT.apply(weight, self.scale)
        # re-wrapping in Parameter creates a new leaf, so the graph
        # back to the original weight and scale is cut here
        self._wrapped_object.weight = Parameter(quantized_weight)
        return self._wrapped_object(x)

    def __repr__(self):
        return "Wrapped" + self._wrapped_object.__repr__()
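One direction I am considering (a sketch, assuming a recent PyTorch with torch.func.functional_call; the class name FunctionalWrapper is mine): run the wrapped module with the quantized weight substituted in, instead of mutating its .weight, so the graph stays intact:

```python
import torch
from torch.func import functional_call
from torch.nn import Linear, Module, Parameter

class FunctionalWrapper(Module):
    def __init__(self, torch_object: Module):
        super().__init__()
        self._wrapped_object = torch_object
        self.scale = Parameter(torch.ones(1))

    def forward(self, x):
        w = self._wrapped_object.weight
        # quantize on the fly (PACT.apply(w, self.scale) would go here);
        # no Parameter() re-wrap, so gradients can still flow to scale
        q = torch.round(w * self.scale) / self.scale
        # run the wrapped module with q substituted for its weight
        return functional_call(self._wrapped_object, {"weight": q}, (x,))

wrapper = FunctionalWrapper(Linear(3, 2, bias=False))
wrapper(torch.randn(5, 3)).sum().backward()
print(wrapper.scale.grad is not None)  # True
```

I have not verified this against the Colab, but functional_call is stateless, so it should stay generic across layer types that expose a weight.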

Colab example with the commented-out workaround:
https://colab.research.google.com/drive/1QfmSisFYmzOMPpDqOFDwT_1ptJc8Zpqx?usp=sharing