Is dynamic quantization in fact doing weight dequant instead of activation quant in `quantize_dynamic()`?

I am curious whether the activations/inputs get quantized before the computation with the quantized weights in the linear layer.

So I tried:

#%%
import torch
import torch.ao.quantization as quant

# Define a model
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(10, 5)

    def forward(self, x):
        print(f"Input dtype: {x.dtype}")
        import pdb; pdb.set_trace()
        x = self.fc(x)
        print(f"Output dtype: {x.dtype}")
        return x
#%%
# Instantiate model
model = MyModel()

# Apply dynamic quantization
quantized_model = quant.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
quantized_model

This gave the output:

MyModel(
  (fc): DynamicQuantizedLinear(in_features=10, out_features=5, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)
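
Feeding an fp32 tensor through the converted model (with the pdb.set_trace() line commented out) just shows fp32 going in and fp32 coming out, so whatever quantization happens must be hidden inside the op:

x = torch.randn(2, 10)       # plain fp32 input
y = quantized_model(x)
# prints:
# Input dtype: torch.float32
# Output dtype: torch.float32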

After further searching for “DynamicQuantizedLinear”, I found the call to torch.ops.quantized.linear_dynamic() in ao/nn/quantized/dynamic/modules/linear.py:

class Linear(nnq.Linear):
...

    def forward(self, x):
        # Note that we can handle self.bias == None case.
        if self._packed_params.dtype == torch.qint8:
            if self.version is None or self.version < 4:
                Y = torch.ops.quantized.linear_dynamic(
                    x, self._packed_params._packed_params)
            else:
                Y = torch.ops.quantized.linear_dynamic(
                    x, self._packed_params._packed_params, reduce_range=True)
        elif self._packed_params.dtype == torch.float16:
            Y = torch.ops.quantized.linear_dynamic_fp16(
                x, self._packed_params._packed_params)
        else:
            raise RuntimeError('Unsupported dtype on dynamic quantized linear!')
        return Y.to(x.dtype)
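
Out of curiosity I also unpacked the stored parameters by hand, using the internal _packed_params attribute and the quantized::linear_unpack op (implementation details, so this may differ between versions). It confirms the weight really is stored as qint8 while the bias stays fp32:

w_q, b = torch.ops.quantized.linear_unpack(
    quantized_model.fc._packed_params._packed_params)
print(w_q.dtype)      # torch.qint8
print(w_q.qscheme())  # torch.per_tensor_affine
print(b.dtype)        # torch.float32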

So I traced torch.ops.quantized.linear_dynamic() further back to the file include/torch/csrc/jit/passes/quantization/quantization_patterns.h, where I found

  std::string linear_dynamic = R"(
graph(%packed_params, %a):
        %w_quant : Tensor, %b : Tensor? = quantized::linear_unpack(%packed_params)
        %w_dequant = aten::dequantize(%w_quant)
        %r = aten::linear(%a, %w_dequant, %b)
        return (%r) )";

This pattern is used in functions like dynamic_quantized_linear_pattern_and_replacements() and dynamic_quant_fusion_pattern_and_replacements() in your IR passes.
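
In eager terms (reusing x and the w_q, b unpacked above), that reference graph amounts to nothing more than:

w_dequant = w_q.dequantize()                      # aten::dequantize(%w_quant)
r = torch.nn.functional.linear(x, w_dequant, b)   # aten::linear(%a, %w_dequant, %b)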

It seems that the pattern actually dequantizes the quantized weights instead of quantizing the fp32 inputs before the operation, but I thought it should be the other way around, as explained in the docs:

Dynamic Quantization
The easiest method of quantization PyTorch supports is called dynamic quantization. This involves not just converting the weights to int8 - as happens in all quantization variants - but also converting the activations to int8 on the fly, just before doing the computation (hence “dynamic”). The computations will thus be performed using efficient int8 matrix multiplication and convolution implementations, resulting in faster compute.
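
Based on that description, I expected something closer to this toy sketch (my own rough reference for the per-tensor qint8 weight above, not what PyTorch actually runs): pick the activation's scale/zero point from the current input, do the integer GEMM, then dequantize the accumulator.

def dynamic_linear_sketch(x_fp32, w_q, bias):
    # 1. Choose activation qparams from *this* input's range ("dynamic").
    x_min, x_max = float(x_fp32.min()), float(x_fp32.max())
    x_scale = max(x_max - x_min, 1e-8) / 255.0
    x_zp = int(min(max(round(-x_min / x_scale), 0), 255))
    x_q = torch.quantize_per_tensor(x_fp32, x_scale, x_zp, torch.quint8)

    # 2. Integer GEMM on the raw int values (done in float here only because
    #    CPU matmul has no int kernel; the real op accumulates in int32).
    acc = (x_q.int_repr().float() - x_zp) @ (
        w_q.int_repr().float().t() - w_q.q_zero_point()
    )

    # 3. Dequantize the accumulator and add the fp32 bias.
    return acc * (x_scale * w_q.q_scale()) + bias

# compare with the quantized layer (reusing x, w_q, b from above);
# the difference should be small (quantization noise)
print((dynamic_linear_sketch(x, w_q, b) - quantized_model.fc(x)).abs().max())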

Or does it actually have something to do with the hardware I am using and the different optimization techniques?

Hi @Chiao-Wei_Hsu, dynamic quantization is quantizing the activations and doing the matrix multiply in low precision. The “quant → dequant → high_precision_gemm” patterns you noticed should be a part of an intermediate representation, and we then lower that representation to a quantized kernel. If you want to see the final kernel, you could run the profiler and look at which kernel is getting executed - it should be the one with quantized activations and quantized gemm.
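
For example, something along these lines should show it (the exact kernel name in the table depends on the backend, e.g. fbgemm vs qnnpack):

from torch.profiler import profile, ProfilerActivity

x = torch.randn(2, 10)
with profile(activities=[ProfilerActivity.CPU]) as prof:
    quantized_model.fc(x)   # call the quantized layer directly
# the table should list an op like quantized::linear_dynamic near the top
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))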