I am curious whether the activations/inputs get quantized before being multiplied with the quantized weights in a linear layer.
So I tried:
#%%
import torch
import torch.ao.quantization as quant

# Define a model
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(10, 5)

    def forward(self, x):
        print(f"Input dtype: {x.dtype}")
        import pdb; pdb.set_trace()
        x = self.fc(x)
        print(f"Output dtype: {x.dtype}")
        return x

#%%
# Instantiate model
model = MyModel()

# Apply dynamic quantization
quantized_model = quant.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
quantized_model
This gave the output:
MyModel(
  (fc): DynamicQuantizedLinear(in_features=10, out_features=5, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)
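For completeness, calling the converted model with a random float input (and continuing past the pdb breakpoint with c) prints float32 for both dtypes:
#%%
x = torch.randn(2, 10)
y = quantized_model(x)
# Input dtype: torch.float32
# Output dtype: torch.float32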
After searching further for “DynamicQuantizedLinear”, I found the torch.ops.quantized.linear_dynamic() call in its forward in ao/nn/quantized/dynamic/modules/linear.py:
class Linear(nnq.Linear):
    ...
    def forward(self, x):
        # Note that we can handle self.bias == None case.
        if self._packed_params.dtype == torch.qint8:
            if self.version is None or self.version < 4:
                Y = torch.ops.quantized.linear_dynamic(
                    x, self._packed_params._packed_params)
            else:
                Y = torch.ops.quantized.linear_dynamic(
                    x, self._packed_params._packed_params, reduce_range=True)
        elif self._packed_params.dtype == torch.float16:
            Y = torch.ops.quantized.linear_dynamic_fp16(
                x, self._packed_params._packed_params)
        else:
            raise RuntimeError('Unsupported dtype on dynamic quantized linear!')
        return Y.to(x.dtype)
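Poking at the converted module shows that only the weight is stored as int8, while the input I pass in stays fp32. A quick check (reusing the quantized_model from above and assuming the per-tensor affine scheme shown in its repr):
#%%
w = quantized_model.fc.weight()           # unpacked quantized weight
print(w.dtype)                            # torch.qint8
print(w.q_scale(), w.q_zero_point())      # per-tensor affine qparams
print(torch.dequantize(w).dtype)          # torch.float32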
So I traced torch.ops.quantized.linear_dynamic() further back to include/torch/csrc/jit/passes/quantization/quantization_patterns.h, where I found
std::string linear_dynamic = R"(
graph(%packed_params, %a):
        %w_quant : Tensor, %b : Tensor? = quantized::linear_unpack(%packed_params)
        %w_dequant = aten::dequantize(%w_quant)
        %r = aten::linear(%a, %w_dequant, %b)
        return (%r) )";
which is used in functions like dynamic_quantized_linear_pattern_and_replacements() and dynamic_quant_fusion_pattern_and_replacements() in your JIT IR passes.
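If I read that pattern correctly, its eager-mode equivalent would be roughly this (my own paraphrase, reusing the quantized_model from above, not code from the repo):
#%%
# %w_quant, %b = quantized::linear_unpack(%packed_params)
w_quant, b = torch.ops.quantized.linear_unpack(
    quantized_model.fc._packed_params._packed_params)
# %w_dequant = aten::dequantize(%w_quant)
w_dequant = torch.dequantize(w_quant)
# %r = aten::linear(%a, %w_dequant, %b)  -- i.e. a plain fp32 linear
r = torch.nn.functional.linear(torch.randn(2, 10), w_dequant, b)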
It seems that these patterns actually dequantize the quantized weights rather than quantizing the fp32 inputs before the operation. I thought it should be the other way around, as explained in the doc:
Dynamic Quantization
The easiest method of quantization PyTorch supports is called dynamic quantization. This involves not just converting the weights to int8 - as happens in all quantization variants - but also converting the activations to int8 on the fly, just before doing the computation (hence “dynamic”). The computations will thus be performed using efficient int8 matrix multiplication and convolution implementations, resulting in faster compute.
Or does it actually have something to do with the hardware I am using and the different optimization techniques?
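To make my expectation concrete, this is roughly what I assumed “converting the activations to int8 on the fly” means (a hand-rolled sketch that only emulates the int8 matmul by dequantizing, not the actual fbgemm/qnnpack kernel):
#%%
import torch.nn.functional as F

def expected_linear_dynamic(x, w_q, b):
    # pick activation qparams on the fly from the current batch
    scale = max((x.max() - x.min()).item(), 1e-8) / 255.0
    zero_point = min(max(int(round(-x.min().item() / scale)), 0), 255)
    x_q = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)
    # a real kernel would do an int8 x int8 matmul here; this sketch
    # dequantizes both operands and falls back to an fp32 matmul
    return F.linear(x_q.dequantize(), w_q.dequantize(), b)

out = expected_linear_dynamic(
    torch.randn(2, 10), quantized_model.fc.weight(), quantized_model.fc.bias())
print(out.dtype)  # torch.float32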