Can we calculate gradients for quantized models?

Fake quantization simulates the effect of quantization (the rounding/clamping error) but keeps everything in high-precision data types like fp32.

So, for example, imagine you were trying to quantize to integers.

Mathematically, a quantized linear op would be:

X = round(X).to(int)
weight = round(weight).to(int)
out = X @ weight

whereas a fake-quantized linear would be:

X = round(X).to(fp32)
weight = round(weight).to(fp32)
out = X @ weight
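
To make the contrast concrete, here is a minimal PyTorch sketch. The `scale` value and the `fake_quant` helper are made up for illustration; they are not a library API:

```python
import torch

def fake_quant(t, scale, qmin=-128, qmax=127):
    # round onto the integer grid, then scale back -- dtype stays fp32
    return torch.clamp(torch.round(t / scale), qmin, qmax) * scale

X = torch.randn(4, 8)
W = torch.randn(8, 3)
scale = 0.1  # made-up per-tensor scale, just for illustration

# "real" quantization: the tensors themselves become integer tensors
X_int = torch.round(X / scale).to(torch.int32)
W_int = torch.round(W / scale).to(torch.int32)

# fake quantization: same rounding error, but everything stays fp32,
# so an ordinary float matmul (and autograd) still applies
out_fake = fake_quant(X, scale) @ fake_quant(W, scale)

print(X_int.dtype, W_int.dtype, out_fake.dtype)
# torch.int32 torch.int32 torch.float32
```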

In practice, quantized weights are stored as quantized tensors, a format designed so that quantized operations can run quickly, which makes them awkward to interact with directly.
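
For instance, with PyTorch's eager-mode per-tensor quantization (the scale and zero point here are arbitrary):

```python
import torch

w = torch.randn(3, 3)

# a "real" quantized tensor: per-tensor affine int8
qw = torch.quantize_per_tensor(w, 0.1, 0, torch.qint8)  # (input, scale, zero_point, dtype)
print(qw.dtype)       # torch.qint8
print(qw.int_repr())  # the underlying int8 values

# most ordinary float ops aren't defined on this dtype, so you generally
# have to dequantize back to fp32 (or go through dedicated quantized kernels)
w_back = qw.dequantize()
print(w_back.dtype)   # torch.float32
```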

Fake-quantized weights, on the other hand, are stored as ordinary floats, so you can interact with them easily and do gradient updates on them.
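
Because everything stays fp32, the master weight can be an ordinary `nn.Parameter` that a standard optimizer updates. A rough sketch (again with a made-up scale, just to show the dtypes involved):

```python
import torch

w = torch.nn.Parameter(torch.randn(3, 3))  # master weight is a plain fp32 tensor
scale = 0.1                                # made-up per-tensor scale

# fake-quantize: round onto the int8 grid, but stay in fp32
w_fake = torch.clamp(torch.round(w / scale), -128, 127) * scale
print(w_fake.dtype)                        # torch.float32

# since w is just a float tensor, the usual update machinery works on it
opt = torch.optim.SGD([w], lr=0.01)
w.grad = torch.ones_like(w)                # pretend some gradient arrived
opt.step()                                 # ordinary fp32 update on the master weight
```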

Most quantized ops do not have a gradient function, so you won't be able to take gradients through them. Note that even quantization-aware training doesn't really compute true gradients of the quantized model; it approximates them with a straight-through estimator, see: [1308.3432] Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation
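
The straight-through estimator amounts to: use the fake-quantized value in the forward pass, but pretend the rounding was the identity in the backward pass. A sketch of the usual `detach` trick (the `fake_quant_ste` helper is my own illustration, not a library function):

```python
import torch

def fake_quant_ste(w, scale, qmin=-128, qmax=127):
    # forward: the fake-quantized value
    # backward: gradient of the identity (straight-through estimator)
    w_q = torch.clamp(torch.round(w / scale), qmin, qmax) * scale
    return w + (w_q - w).detach()

w = torch.nn.Parameter(torch.randn(4, 4))
x = torch.randn(2, 4)

out = x @ fake_quant_ste(w, scale=0.05)
out.sum().backward()

print(w.grad is not None)  # True: gradients flow "through" the rounding
```

Without the `detach` trick, the gradient of `round` is zero almost everywhere, so the fp32 master weights would never receive a useful update.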