Can `register_forward_pre_hook` modify a module's weight, and is the modified weight learnable?

The PyTorch docs say the hook can modify the input, but there is no documentation about module weights. My question is: can `register_forward_pre_hook` modify a module's weight, and is the modified weight learnable?
I noticed that `weight_norm` uses this method to modify the weight:

  1. Delete the original weight `w_original` from `_parameters`.
  2. Register a new weight `w_new` using `register_parameter`.
  3. Use `register_forward_pre_hook` to register a hook function that recomputes `w_original` from `w_new` and sets it on the module.

I want to know: after these three steps, does `w_original` point to `w_new`, and can `w_new` be updated with SGD?
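
Here is a minimal sketch of what I mean, on an `nn.Linear` (the names are made up for illustration; this is not the actual weight_norm code):

import torch
import torch.nn as nn

# Sketch of the three steps above (illustrative names, not the real weight_norm code)
lin = nn.Linear(4, 2)
w_original = lin.weight.detach().clone()

del lin._parameters['weight']                               # 1. delete the original weight
lin.register_parameter('w_new', nn.Parameter(w_original))   # 2. register the new weight

def pre_hook(module, inputs):
    # 3. recompute the `weight` attribute from `w_new` before every forward pass
    setattr(module, 'weight', module.w_new * 1.0)

lin.register_forward_pre_hook(pre_hook)

loss = lin(torch.randn(3, 4)).sum()
loss.backward()
print(lin.w_new.grad)   # is this non-zero, i.e. does SGD update w_new?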

I tried to simulate quantization using `register_forward_pre_hook` by changing `compute_weight` in `weight_norm` from

def compute_weight(self, module: Module) -> Any:
    g = getattr(module, self.name + '_g')
    v = getattr(module, self.name + '_v')
    return _weight_norm(v, g, self.dim)

to

# I want to simulate Q10 weight quantization
def compute_weight(self, module: Module) -> Any:
    w = getattr(module, self.name + 'quantized')
    return torch.round(w * 2**10) / 2**10

but the loss remains constant. Could someone help me figure it out?
Thanks a lot!

Isn’t torch.round non-differentiable?

torch.round is technically differentiable, but its gradient is zero everywhere, so it's not really usefully differentiable:

import torch

x = torch.randn(5, requires_grad=True)
out = torch.round(x)
out.mean().backward()
print(x.grad)
# tensor([0., 0., 0., 0., 0.])

The derivative is defined in derivatives.yaml as:

- name: round(Tensor self) -> Tensor
  self: zeros_like(grad)
  result: auto_element_wise
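
If you want the quantized weight to stay learnable despite this zero gradient, a common workaround (just a sketch, not something from this thread) is a straight-through estimator: use the rounded value in the forward pass, but let the gradient pass through as if round were the identity. `round_ste` and `frac_bits` below are made-up names:

import torch

def round_ste(x: torch.Tensor, frac_bits: int = 10) -> torch.Tensor:
    # Forward: Q-format rounding with `frac_bits` fractional bits.
    # Backward: the gradient flows through as if round were the identity.
    scale = 2 ** frac_bits
    q = torch.round(x * scale) / scale
    return x + (q - x).detach()

x = torch.randn(5, requires_grad=True)
round_ste(x).mean().backward()
print(x.grad)
# tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])  -- non-zero, unlike plain torch.round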

I want to convert a trained model's weights from float32 to Q10 fixed-point numbers, but performance drops by about 20%. Could you give some advice on how to take the quantization error (from float32 to Q10 fixed point) into consideration?
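
For concreteness, a small sketch of the quantization error itself, assuming Q10 means 10 fractional bits (matching the round(w * 2**10) / 2**10 above):

import torch

w = torch.randn(1000) * 0.1             # stand-in for trained float32 weights
scale = 2 ** 10                         # Q10: 10 fractional bits
w_q10 = torch.round(w * scale) / scale  # same rounding as in compute_weight above
err = (w - w_q10).abs()
print(err.max(), err.mean())            # per-weight error is bounded by 2**-11 ~ 4.9e-4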