Inserting Unnecessary Fake Quants during QAT?

I’m confused about the way PyTorch inserts FakeQuant nodes during Quantization-Aware Training. The documentation shows an image of the QAT graph in which FakeQuant nodes are inserted after both the weight W and the input x of a Linear op.

I understand the FakeQuant after W, since the weight will actually be quantized. But why insert a FakeQuant node after the input x, which forces Quant/Dequant overhead at inference time? The Linear node runs in fp32 anyway, so this looks like pure overhead.
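
For concreteness, here is roughly what I understand a FakeQuant node to compute (the scale/zero_point values below are just placeholders): quantize to the int8 grid, then dequantize right away, so both the activation and the weight feeding the fp32 Linear carry quantization error during training.

```python
import torch
import torch.nn.functional as F

def fake_quant(x, scale, zero_point, qmin=-128, qmax=127):
    # Quantize to the int8 grid, then dequantize immediately: the tensor
    # stays fp32, but its values are snapped to what int8 quantization
    # would produce.
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale

x = torch.randn(4, 8)
w = torch.randn(16, 8)

# QAT inserts a FakeQuant on the activation *and* on the weight,
# then runs the Linear op itself in fp32.
x_fq = fake_quant(x, scale=0.05, zero_point=0)
w_fq = fake_quant(w, scale=0.02, zero_point=0)
out = F.linear(x_fq, w_fq)
```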

I’ve noticed this in practice: models exported to ONNX from PyTorch have zillions of Quant/Dequant pairs interspersed throughout, which seem unnecessary.

A) This is for dynamic quantization, where you quantize the activations as well.
B) The documentation you cite describes a lowering process: you do all the math needed to figure out what the quantization parameters and weights would be (prepare), then you get the graph into a special format (convert), which is not necessarily performant, and then you use a lowering pass to recognize these special forms and swap them for performant ones (see the sketch after the diagram below). This lowering step would recognize the above graph as a dynamically quantized linear op, where the actual graph is

x_fp → lowered_q_linear → out_fp
w_q--------^
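
As a rough sketch of the prepare → convert → lower flow with FX graph mode quantization (assuming a recent PyTorch with torch.ao.quantization; the toy model and the default static QAT qconfig here are just illustrative):

```python
import copy
import torch
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import (
    prepare_qat_fx,
    convert_to_reference_fx,
    convert_fx,
)

model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU())
example_inputs = (torch.randn(1, 8),)

# prepare: insert FakeQuant modules on the activations and the weights
prepared = prepare_qat_fx(
    model.train(), get_default_qat_qconfig_mapping("fbgemm"), example_inputs
)

# ... run the QAT training loop on `prepared` here ...

prepared.eval()

# convert to the reference format: quant/dequant pairs around fp32 ops,
# i.e. the special, not-necessarily-performant form described above
reference = convert_to_reference_fx(copy.deepcopy(prepared))

# convert_fx additionally runs the lowering pass, pattern-matching the
# reference graph into fused quantized kernels (e.g. quantized::linear)
lowered = convert_fx(prepared)
print(lowered.graph)
```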


Thank you, this is exactly what I was missing. So this is a sort of intermediate representation that’s expected to be optimized down into a quantized kernel.