Using AO Quantization package in pytorch I am doing post training quantization. I have a concatenation layer which concatenates 2 Quantized tensors with their own scale and zero_point. Ideally when 2 tensors are concatenated a new scaling and zero_point should be calculated implicitly due to change in overall min and max values. However in pytorch, scale and zero_point for the concatenated tensor are taken that of the first tensor in the concat list. No new value is calculated.
Please suggest solutions. I am sharing the code and the outputs for easier comprehension. As you can see tensor Y and tensor tmp as in code have same scale and zero_point.
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleConvNet(nn.Module):
def __init__(self):
super(SimpleConvNet, self).__init__()
self.quant = torch.quantization.QuantStub()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False)
self.relu2 = nn.ReLU()
self.dequant = torch.quantization.DeQuantStub()
def forward(self,x):
x = self.quant(x)
y = self.conv1(x)
y = self.relu1(y)
z = self.conv2(y)
z = self.relu2(z)
tmp = torch.cat([y,z],dim=1) #ISSUE HERE
if y.dtype == torch.quint8:
print('In Quantized Model')
print('Tensor Y =>',y,'\n\n')
print('Tensor Z =>',z,'\n\n')
print('Concatenated Tensor =>',tmp,'\n\n')
tmp = self.dequant(tmp)
return tmp
def main():
device = 'cpu'#torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleConvNet().to(device)
model.qconfig = torch.ao.quantization.default_qconfig#get_default_qconfig('x86')
model = torch.ao.quantization.prepare(model)
torch.manual_seed(0)
inp = torch.rand((1,1,2,2))*20000-20
inp = inp.to(device)
output = model(inp)
model = torch.ao.quantization.convert(model)
model(inp)
if __name__ == "__main__":
main()
PRINT OUTPUT
In Quantized Model
Tensor Y => tensor([[[[ 85.7967, 1816.0299],
[ 958.0630, 486.1812]]]], size=(1, 1, 2, 2), dtype=torch.quint8,
quantization_scheme=torch.per_tensor_affine, scale=14.299448013305664,
zero_point=0)
Tensor Z => tensor([[[[ 0.0000, 293.9109],
[ 92.2744, 0.0000]]]], size=(1, 1, 2, 2), dtype=torch.quint8,
quantization_scheme=torch.per_tensor_affine, scale=3.417569160461426,
zero_point=40)
Concatenated Tensor => tensor([[[[ 85.7967, 1816.0299],
[ 958.0630, 486.1812]],
[[ 0.0000, 300.2884],
[ 85.7967, 0.0000]]]], size=(1, 2, 2, 2), dtype=torch.quint8,
quantization_scheme=torch.per_tensor_affine, scale=14.299448013305664,
zero_point=0)