Simple Custom QAT .backward() function

I have a model that I need to quantize down to a very low bit depth (3-5 bits), so I cannot use the built-in quantization package, which only supports 8-bit numbers.

My model can be called in two ways: Model(x) and Model.quantized_forward(x), the latter of which involves a non-differentiable round operation. I want to implement a simple quantization-aware L2 loss function of the form:

MSE(x) = || Model(x) - Target ||^2
Grad(MSE(x)) = Grad(Model(x))^T (Model.quantized_forward(x) - Target)

I am very confused about how to write the backward function for something like this, which uses the gradient of the full-precision model as well as a custom residual term. Can anyone help me get unstuck?
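In case it helps clarify what I'm after, here is a minimal sketch of one way I think this could be done without a hand-written .backward(): build a "straight-through" surrogate whose value equals the quantized output but whose gradient flows through the full-precision path via detach(). The TinyModel, fake_quantize, and quantized_forward below are toy stand-ins I made up for illustration, not my real model:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def fake_quantize(t, bits=4):
    # Uniform fake-quantization: round to 2**bits levels in [-1, 1].
    # torch.round has zero gradient almost everywhere, which is the problem.
    scale = (2 ** (bits - 1)) - 1
    return torch.round(t.clamp(-1, 1) * scale) / scale

class TinyModel(nn.Module):
    # Toy stand-in for Model(x) / Model.quantized_forward(x).
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

    def quantized_forward(self, x):
        # Quantize the weights before the matmul; gradients die at round().
        w_q = fake_quantize(self.linear.weight)
        b_q = fake_quantize(self.linear.bias)
        return F.linear(x, w_q, b_q)

def qat_mse(model, x, target):
    y = model(x)                    # differentiable full-precision path
    q = model.quantized_forward(x)  # non-differentiable quantized path
    # Straight-through surrogate: its VALUE equals q, but its GRADIENT
    # is dy/dtheta, because the detach() calls cut the other paths.
    surrogate = q.detach() + y - y.detach()
    return ((surrogate - target) ** 2).sum()

model = TinyModel()
x = torch.randn(8, 2)
target = torch.randn(8, 1)
loss = qat_mse(model, x, target)
loss.backward()  # weight.grad = 2 * J(Model)^T (quantized_forward(x) - target)
```

If I've reasoned correctly, autograd then produces exactly the gradient in my second equation (up to the usual factor of 2 from the squared norm), with no custom autograd.Function needed.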

Thank you!