Quantization Bug in Concatenation of Tensors

I am doing post-training quantization with PyTorch's AO quantization package. I have a concatenation layer that concatenates two quantized tensors, each with its own scale and zero_point. Ideally, when two tensors are concatenated, a new scale and zero_point should be computed implicitly, because the overall min and max values change. In PyTorch, however, the concatenated tensor simply takes the scale and zero_point of the first tensor in the cat list; no new values are calculated.

Please suggest solutions. I am sharing the code and the output below for easier comprehension. As you can see, tensor Y and tensor tmp in the code end up with the same scale and zero_point.
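
For reference, this is roughly the recomputation I would expect for the concatenated tensor (just an illustrative sketch assuming per-tensor affine quint8 with qmin=0 and qmax=255; expected_qparams is my own helper, not a PyTorch API, and it skips the clamping/degenerate-range handling that real observers do):

def expected_qparams(float_tensors, qmin=0, qmax=255):
    # Combined float range of everything that feeds the cat
    lo = min(float(t.min()) for t in float_tensors)
    hi = max(float(t.max()) for t in float_tensors)
    lo, hi = min(lo, 0.0), max(hi, 0.0)   # representable range must include 0
    scale = (hi - lo) / (qmax - qmin)
    zero_point = int(round(qmin - lo / scale))
    return scale, zero_point

# e.g. expected_qparams([y.dequantize(), z.dequantize()])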

import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleConvNet(nn.Module):
    def __init__(self):
        super(SimpleConvNet, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu2 = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()
    
    def forward(self,x):
        x = self.quant(x)
        y = self.conv1(x)
        y = self.relu1(y)
        z = self.conv2(y)
        z = self.relu2(z)
        tmp = torch.cat([y,z],dim=1) #ISSUE HERE
        
        if y.dtype == torch.quint8:
            print('In Quantized Model')
            print('Tensor Y =>',y,'\n\n')
            print('Tensor Z =>',z,'\n\n')
            print('Concatenated Tensor =>',tmp,'\n\n')
        tmp = self.dequant(tmp)
        return tmp

def main():
    device = 'cpu'#torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SimpleConvNet().to(device)
    model.qconfig = torch.ao.quantization.default_qconfig#get_default_qconfig('x86')
    model = torch.ao.quantization.prepare(model)
    torch.manual_seed(0)
    inp = torch.rand((1,1,2,2))*20000-20
    inp = inp.to(device)
    output = model(inp)  # calibration pass: observers record activation ranges
    model = torch.ao.quantization.convert(model)
    model(inp)  # run the converted (quantized) model; the prints in forward() fire here
    

if __name__ == "__main__":
    main()

PRINT OUTPUT
In Quantized Model
Tensor Y => tensor([[[[ 85.7967, 1816.0299],
                      [ 958.0630, 486.1812]]]], size=(1, 1, 2, 2), dtype=torch.quint8,
                    quantization_scheme=torch.per_tensor_affine, scale=14.299448013305664,
                    zero_point=0)

Tensor Z => tensor([[[[ 0.0000, 293.9109],
                      [ 92.2744, 0.0000]]]], size=(1, 1, 2, 2), dtype=torch.quint8,
                    quantization_scheme=torch.per_tensor_affine, scale=3.417569160461426,
                    zero_point=40)

Concatenated Tensor => tensor([[[[ 85.7967, 1816.0299],
                                 [ 958.0630, 486.1812]],
                                [[ 0.0000, 300.2884],
                                 [ 85.7967, 0.0000]]]], size=(1, 2, 2, 2), dtype=torch.quint8,
                               quantization_scheme=torch.per_tensor_affine, scale=14.299448013305664,
                               zero_point=0)
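
For completeness, the same can be checked directly from the tensor qparams; something like this inside the quantized branch of forward() (q_scale()/q_zero_point() are only defined on per-tensor quantized tensors; values taken from the run above):

print(y.q_scale(), y.q_zero_point())      # 14.2994..., 0
print(z.q_scale(), z.q_zero_point())      # 3.4175..., 40
print(tmp.q_scale(), tmp.q_zero_point())  # 14.2994..., 0  -> reused from y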

The intended solution is to use the FloatFunctional module to do the cat operation, which then gets quantized correctly (it carries its own observer, so the concatenated output gets its own scale and zero_point). Incidentally, this is also how you'd handle add, mul, etc.

In general, eager mode quantization can't handle functional operations; they have to be converted to modules so the quantization machinery can interact with them.

Specifically, this method for cat:
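
That is torch.ao.nn.quantized.FloatFunctional.cat. A minimal sketch of the model above rewritten to use it (assuming a recent PyTorch with the torch.ao namespace; on older versions the class lives at torch.nn.quantized.FloatFunctional):

import torch
import torch.nn as nn

class SimpleConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.conv1 = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu2 = nn.ReLU()
        # FloatFunctional carries its own observer, so prepare()/convert()
        # give the cat output its own scale/zero_point instead of reusing y's.
        self.cat = torch.ao.nn.quantized.FloatFunctional()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        y = self.relu1(self.conv1(x))
        z = self.relu2(self.conv2(y))
        tmp = self.cat.cat([y, z], dim=1)  # instead of torch.cat([y, z], dim=1)
        return self.dequant(tmp)

The rest of the flow (qconfig, prepare, calibration pass, convert) stays the same; after convert, self.cat becomes a QFunctional whose scale and zero_point come from the observed range of the concatenated output.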