Does "De-Quantization" happens/executed after every Conv/Linear layer when using "torch.quantization.quantize_dynamic"

Hi,
I’m curious about how/why the output tensor of a "dynamic-quantized" model can stay close to the output of the original model.

In my understanding, during the arithmetic stage, if the weights of some layer are quantized (i.e., the floating-point weights are scaled by some factor and shifted by a zero-point) and the input is quantized as well, then the output of this layer is scaled and shifted (linearly) relative to the original. So before we send the quantized output of this layer into the next non-linear layer, do we need to "de-quantize" it first? After all, the value range/distribution of the output tensor is scaled and shifted compared with the original model's output. And if we don't "de-quantize" it, the next non-linear layer (or the next quantized layer) will make the gap between the quantized model's output and the original model's output grow larger and larger, because the scaling and shifting accumulate.
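For concreteness, here is a tiny sketch of the affine mapping I mean (the scale and zero-point values are made up, not taken from any real model):

import torch

x = torch.randn(4)
scale, zero_point = 0.1, 10        # made-up quantization parameters

# quantize: x_q = round(x / scale) + zero_point, stored as uint8
x_q = torch.quantize_per_tensor(x, scale, zero_point, dtype=torch.quint8)
print(x_q.int_repr())              # the raw integer values (scaled and shifted)

# dequantize: (x_q - zero_point) * scale, which recovers x up to rounding error
print(x_q.dequantize())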

So I’d like to know more details about how torch’s dynamic quantize works:

  1. Does each quantized layer "de-quantize" its output from int32 (or int16) back to float32? The documentation seems to say that only the last output layer does this.

  2. If the hidden/intermediate quantized layers' outputs are not "de-quantized" during inference, are they at least re-scaled and shifted in the integer domain, so that the output values' range/distribution stays close to the original model's?

  1. Yes, for dynamic quant the output is floating point; it can be float32 or bf16 or something else depending on the op. For static quantization your input and output are quantized, so you need a quant/dequant op at the start and end of any chunk of quantized operations (there is a short sketch after the diagrams below).

dynamic
float → dynamic op → float

static
quantized dtype → static op → quantized dtype
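As a quick sketch (my own toy example, not from the docs), you can check the dynamic case with a single Linear layer:

import torch
import torch.nn as nn

# Dynamically quantize one Linear layer and check that the activations stay
# in floating point; only the weights are stored as int8.
model = nn.Sequential(nn.Linear(16, 8))
qmodel = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

x = torch.randn(4, 16)   # float32 in
y = qmodel(x)            # float32 out
print(type(qmodel[0]))   # a dynamic quantized Linear module
print(y.dtype)           # torch.float32

The weights are stored as int8 inside the dynamic Linear, but the activation tensors you pass in and get back are ordinary float tensors.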

  2. The static ops still do a type of dequant, just not to float (a rough numeric sketch of this follows the link below). They do:

quantized dtype (8-bit) → quantized op → quantized dtype (32-bit) → quantized dtype (8-bit)

you can find more info here: gemmlowp/doc/quantization.md at master · google/gemmlowp · GitHub
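Here's a rough numeric sketch of that requantization step (my own toy numbers, following the gemmlowp write-up linked above; all the scales and zero-points are made up):

import numpy as np

s_x, zp_x = 0.02, 3      # input scale / zero-point
s_w, zp_w = 0.01, 0      # weight scale / zero-point (symmetric weights)
s_y, zp_y = 0.05, 5      # output scale / zero-point from calibration

x_q = np.random.randint(0, 256, size=(1, 4)).astype(np.int32)     # quint8 input
w_q = np.random.randint(-128, 128, size=(4, 2)).astype(np.int32)  # qint8 weights

# integer matmul accumulates in int32
acc = (x_q - zp_x) @ (w_q - zp_w)

# requantize: rescale the int32 accumulator onto the output's 8-bit grid
y_q = np.clip(np.round(acc * (s_x * s_w / s_y)) + zp_y, 0, 255).astype(np.uint8)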

Hey, thanks for the answer!

You said that "For static quantization your input and output are quantized and so you need a quant/dequant op at the start and end of any chunk of quantized operations". But when I look at the tutorial notebook from PyTorch about static quantization, I notice the author only adds the "QuantStub()" and "DeQuantStub()" at the start and the end of the whole neural network. Does torch implicitly add the other quant/dequant ops for the rest of the layers in the NN?

The example code from the tutorial notebook that confuses me is shown below; you can see that "QuantStub()" and "DeQuantStub()" only appear at the start and end of the neural network, according to the forward method.

import torch.nn as nn
# QuantStub/DeQuantStub mark where tensors enter and leave the quantized region
from torch.quantization import QuantStub, DeQuantStub
# _make_divisible, ConvBNReLU and InvertedResidual are helper definitions from
# earlier in the same tutorial notebook.

class MobileNetV2(nn.Module):
    def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None, round_nearest=8):
        """
        MobileNet V2 main class
        Args:
            num_classes (int): Number of classes
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
            Set to 1 to turn off rounding
        """
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280

        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "or a 4-element list, got {}".format(inverted_residual_setting))

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU(3, input_channel, stride=2)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
        # make it nn.Sequential
        self.features = nn.Sequential(*features)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.quant(x)
        x = self.features(x)
        x = x.mean([2, 3])
        x = self.classifier(x)
        x = self.dequant(x)
        return x

That's because it's a toy example where you apply quantization to the whole model. You need one of those stubs before/after any quantized chunk. Not to put too fine a point on it: eager mode quantization is 'dumb'. It doesn't know anything about the model as a whole; if you want to quantize a module, that module gets quantized, but that's it. It doesn't know what came before, what dtype it will see at runtime, etc. If the conversion breaks stuff, that stuff is just broken. You need to add stubs to manually encode that holistic information: "quant stub here because the next op is going to be quantized; after the quantized op we put a dequant stub because the op after that is not quantized", etc.
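For example, a minimal sketch (mine, not from the tutorial) of stubs around just one quantized chunk, with the classifier deliberately left in float:

import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub, get_default_qconfig, prepare, convert

class PartiallyQuantized(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()       # float -> quint8, because the conv will be quantized
        self.conv = nn.Conv2d(3, 8, 3)
        self.dequant = DeQuantStub()   # quint8 -> float, because the rest stays in float
        self.fc = nn.Linear(8, 10)     # deliberately left unquantized

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.dequant(x)
        x = x.mean([2, 3])
        return self.fc(x)

m = PartiallyQuantized().eval()
m.qconfig = get_default_qconfig("fbgemm")
m.fc.qconfig = None                    # eager mode skips modules with no qconfig
prepare(m, inplace=True)
m(torch.randn(1, 3, 32, 32))           # calibration pass
convert(m, inplace=True)               # conv becomes a quantized Conv2d, fc stays float

Because eager mode only converts what you tell it to, the dequant has to sit right where the quantized chunk ends; otherwise the float Linear would receive a quantized tensor and fail.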

FX quantization, on the other hand, doesn't need any stubs and implicitly inserts them. The difference is that FX quantization traces the model, so it has access to the graph and knows where quantized ops occur and the dtypes at each step. When it sees a non-quantized dtype going into a quantized op (or vice versa), it adds a conversion operator. However, we've now added a traceability requirement to the flow, which greatly constrains usability for anyone doing anything remotely non-standard in their model. Have an if statement? Not traceable. A for loop? Not traceable, etc. It's also a lot more complicated, so when things go wrong, it can be harder to figure out why.
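For comparison, a rough FX sketch (API names follow recent torch.ao.quantization releases; older versions take a qconfig dict instead of a QConfigMapping):

import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

# No stubs anywhere: FX traces the model, so it can insert quantize/dequantize
# nodes wherever the dtype changes between float and quantized ops.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3)).eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

qconfig_mapping = get_default_qconfig_mapping("fbgemm")
prepared = prepare_fx(model, qconfig_mapping, example_inputs)
prepared(*example_inputs)              # calibration pass
quantized = convert_fx(prepared)
print(quantized.graph)                 # the inserted quantize/dequantize ops show up here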

Thanks for the detailed explanation!