Quantization parameters in QuantizedConv2d

I would like to find where the parameters quant_max, quant_min, min_val, max_val are stored in a QuantizedConv2d block. I was able to locate them in the observers using the following code:

import torch
from torch.quantization.observer import MinMaxObserver, MovingAverageMinMaxObserver, HistogramObserver

C, L = 3, 4
normal = torch.distributions.normal.Normal(0, 1)
inputs = [normal.sample((C, L)), normal.sample((C, L))]

observers = [MinMaxObserver(), MovingAverageMinMaxObserver(), HistogramObserver()]
for obs in observers:
    for x in inputs:
        obs(x)
    print(obs.__class__.__name__, obs.calculate_qparams(), obs.quant_min, obs.quant_max, obs.min_val, obs.max_val)

The output being

MinMaxObserver (tensor([0.0113]), tensor([136], dtype=torch.int32)) 0 255 tensor(-1.5460) tensor(1.3433)
MovingAverageMinMaxObserver (tensor([0.0111]), tensor([134], dtype=torch.int32)) 0 255 tensor(-1.4766) tensor(1.3414)
HistogramObserver (tensor([0.0082]), tensor([143], dtype=torch.int32)) 0 255 tensor(-1.5460) tensor(1.3612)

However, I am not able to locate the same parameters after converting a Conv2d block into a QuantizedConv2d block:

import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.dequant(x)
        return x

model_fp32 = M()
model_fp32.eval()
model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model_fp32_prepared = torch.quantization.prepare(model_fp32)
model_fp32_converted = torch.quantization.convert(model_fp32_prepared, inplace=True)

These are the weights of the model after passing the inputs through it:

model_fp32_converted(inputs)
model_fp32_converted.conv.weight()
tensor([[[[ 0.03294463083148002625, -0.13431271910667419434,
           -0.21540719270706176758],
          [ 0.14191532135009765625, -0.14191532135009765625,
           -0.24835182726383209229],
          [ 0.32184368371963500977, -0.14444953203201293945,
           -0.21287299692630767822]]]], size=(1, 1, 3, 3), dtype=torch.qint8,
       quantization_scheme=torch.per_channel_affine,
       scale=tensor([0.00253420230001211166], dtype=torch.float64),
       zero_point=tensor([0]), axis=0)

I would like to know

  • How to get quant_max, quant_min, min_val, max_val from this model_fp32_converted.conv block?

  • Where in the GitHub code do these parameters from the observers get stored in the QuantizedConv2d block?


I don’t think the QuantizedConv2d needs these values (it needs the scale and zero point it should use for the various inputs and outputs, but not the observed min/max), so I doubt that you can get them from it.
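If it helps, here is a minimal sketch (assuming the eager-mode torch.nn.quantized.Conv2d API and the model_fp32_converted from your snippet) of the quantization parameters that can be read back directly:

# qparams that are stored on the converted module
print(model_fp32_converted.conv.scale)        # output activation scale
print(model_fp32_converted.conv.zero_point)   # output activation zero point

w = model_fp32_converted.conv.weight()        # per-channel quantized weight tensor
print(w.q_per_channel_scales())               # per-output-channel scales
print(w.q_per_channel_zero_points())          # per-output-channel zero points

# the observed min_val / max_val are observer state and are dropped at convert time,
# so there is nothing equivalent to read here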

Best regards

Thomas

Hi Thomas, thanks for your response. If I understood the quantization process correctly, I think there should be an observer block inside the QuantizedConv2d block which stores the min and max values from which the scale and zero_point are computed.

Here is the link to where the observer is computing the scale and zero point
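For reference, here is a small sketch (using the model_fp32_prepared from my snippet above, before convert is called, and assuming the eager-mode attribute name activation_post_process) of the observer state I am referring to:

# the observers live on the *prepared* model as activation_post_process children
x = torch.randn(1, 1, 4, 4)
model_fp32_prepared(x)                                   # calibration pass fills the observers

obs = model_fp32_prepared.conv.activation_post_process   # observer attached by prepare()
print(obs.min_val, obs.max_val)                          # observed range of the conv output
print(obs.quant_min, obs.quant_max)                      # target integer range
print(obs.calculate_qparams())                           # scale / zero_point that convert() will bake in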

Hi Avishek,

Were you able to figure out the answers to your questions? I would also like to understand how the Observer modules in PyTorch are used to set the clipping range for the quantized blocks, especially the convolution blocks. I want to experiment with different statistical observations if possible, so I would be grateful for any advice in that direction as well. Thank you.
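To make the last point concrete, here is a rough sketch of the kind of thing I would like to plug in (assuming the standard QConfig/observer API; PercentileObserver is a made-up name, not an existing PyTorch class):

import torch
from torch.quantization import QConfig, default_per_channel_weight_observer
from torch.quantization.observer import MinMaxObserver

class PercentileObserver(MinMaxObserver):
    # hypothetical observer: feed the stock min/max tracking a percentile-clipped
    # view of the data instead of the raw extremes
    def forward(self, x_orig):
        x = x_orig.detach().float()
        lo = torch.quantile(x, 0.01)
        hi = torch.quantile(x, 0.99)
        super().forward(torch.clamp(x, lo.item(), hi.item()))
        return x_orig

my_qconfig = QConfig(
    activation=PercentileObserver.with_args(dtype=torch.quint8),
    weight=default_per_channel_weight_observer,
)
# model_fp32.qconfig = my_qconfig   # then prepare / calibrate / convert as usual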

Atharv

I am also looking for a way to extract these parameters from a quantized model containing QuantizedConv2d, especially the quant_min and quant_max used for requantization. After invoking

model_fp32_converted = torch.quantization.convert(model_fp32_prepared, inplace=True)

the observers no longer exist in model_fp32_converted; the process you described happens only during calibration, not during inference.
Can @HDCharles help us?

Those values are generally hardcoded into the backend. See self._packed_params, which holds the relevant information in an expected form that is then passed to the relevant op in the forward.

These ops are understood to use int8 dtypes, which have a pre-specified qmin/qmax. Although you can set things up to use a different qmin/qmax and get a different scale/zero_point, in the actual calculation it is going to clamp to whatever the backend is set to use.

e.g.

it's not passing the qmin/qmax to the requant kernel; it's assumed.
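A rough sketch of what that looks like from Python (assuming the current eager-mode internals; _packed_params and _weight_bias() are implementation details and may change):

qconv = model_fp32_converted.conv
print(type(qconv._packed_params))      # opaque, backend-specific packed weight/bias

w, b = qconv._weight_bias()            # unpack into a quantized weight + float bias
print(w.dtype)                         # torch.qint8 -> the clamp range is implied by the dtype
print(qconv.scale, qconv.zero_point)   # requantization params for the output
# no quant_min / quant_max / min_val / max_val are stored on the module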

If I specify the activation configuration in the qconfig (for example, [0, 255] to [0, 235]), how does the backend know the quant_max and quant_min to clamp the activation during inference, for example in quantized::conv2d or aten::quantize_per_tensor?

Oh, I see. The output activation tensor's range is always [0, 255] or [-128, 127], which is controlled only by the dtype, and the qscheme is always per_tensor_affine. The weight's range can be configured by the qconfig (for example, reduce_range) and it will be fixed in the serialized model. So quant_max and quant_min are not stored with the weight and activation; those two parameters are implicitly represented by the dtype.
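A small sketch of recovering that implied range from the dtype (assuming int_repr() on quantized tensors and reusing model_fp32_converted from above):

w = model_fp32_converted.conv.weight()     # torch.qint8 quantized weight
storage = w.int_repr()                     # underlying integer tensor
info = torch.iinfo(storage.dtype)
print(storage.dtype, info.min, info.max)   # torch.int8 -128 127

# activations default to quint8 in eager mode, i.e. [0, 255]; a reduced qmin/qmax in the
# qconfig only changes the scale/zero_point computed during calibration, not the range
# the backend clamps to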

Yep, there is, however, an option in eager-mode quant to keep the qconfigs:

https://pytorch.org/docs/stable/generated/torch.ao.quantization.convert.html

You could get the qmin/qmax from that.
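Roughly like this (a sketch assuming the remove_qconfig flag of the eager-mode convert() and a freshly prepared model; the qconfig attribute stays on the unswapped parent module):

model_converted = torch.quantization.convert(
    model_fp32_prepared, inplace=False, remove_qconfig=False
)

qconfig = model_converted.qconfig            # kept because remove_qconfig=False
act_obs = qconfig.activation()               # fresh (uncalibrated) observer instance
wt_obs = qconfig.weight()
print(act_obs.quant_min, act_obs.quant_max)  # integer range the qconfig was configured with
print(wt_obs.quant_min, wt_obs.quant_max)
# min_val / max_val are still unavailable: they lived on the calibrated observers
# that convert() discarded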