I am trying to use PyTorch’s quantized ops, but I notice that accuracy drops in some cases relative to other quantization frameworks. Inspecting further, I find two problem cases, depending on whether a MinMaxObserver has reduce_range=True or reduce_range=False. If reduce_range is False, nn.quantized.Conv2d no longer behaves as expected: in the example below, running an ImageNet sample through the first conv layer of ResNet50 produces output that differs from quantizing-dequantizing the weights/inputs and running the fp32 conv op. If reduce_range is True, this mismatch disappears, but overall accuracy still drops and the weight quantization error goes up (see script).
I know that similar issues have been brought up before, and it seems the reduce_range parameter was introduced precisely to resolve them. But reducing the quantization range by one bit inevitably costs accuracy, so reduce_range is not an all-round fix.
Any clue what’s causing these issues? Are there fixes beyond reduce_range? Thanks.
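To make the accuracy cost concrete: here is a minimal sketch of the per-tensor affine quantization formula, using the same weight range as the script below. The qmin/qmax values are the standard 8-bit signed ranges ([-128, 127] full, [-64, 63] with one bit dropped, as reduce_range does for qint8); the helper function is my own illustration, not a PyTorch API.

```python
# Sketch: how reduce_range changes qint8 quantization parameters for the
# weight range used in the script below. Affine scheme assumed:
#   scale = (max - min) / (qmax - qmin)
weight_min, weight_max = -0.494000118970871, 0.6586668252944946

def affine_qparams(vmin, vmax, qmin, qmax):
    """Per-tensor affine scale/zero_point, keeping 0.0 representable."""
    vmin, vmax = min(vmin, 0.0), max(vmax, 0.0)
    scale = (vmax - vmin) / (qmax - qmin)
    zero_point = round(qmin - vmin / scale)
    return scale, zero_point

# reduce_range=False: full 8-bit signed range [-128, 127]
scale_full, zp_full = affine_qparams(weight_min, weight_max, -128, 127)
# reduce_range=True: one bit dropped, range [-64, 63]
scale_reduced, zp_reduced = affine_qparams(weight_min, weight_max, -64, 63)

print(scale_full, scale_reduced)  # the reduced scale is ~2x coarser
```

Halving the integer range roughly doubles the quantization step, which is the extra weight error the script measures.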
import torch
from torch import nn
from torchvision import transforms
import torchvision.models as tvm
from torchvision.datasets.folder import default_loader
from torch.nn.quantized.modules.utils import _quantize_weight
# Pre-computed quantization parameters
input_scale = 0.01865844801068306
input_zero_point = 114
weight_max = torch.Tensor([0.6586668252944946])
weight_min = torch.Tensor([-0.494000118970871])
output_scale = 0.04476073011755943
output_zero_point = 122
x_filename = 'ILSVRC2012_val_00000293.JPEG'
imsize = 224
val_transforms = transforms.Compose([transforms.ToTensor()])
model = tvm.resnet50(pretrained=True)
cmod = model.conv1
x_raw = default_loader(x_filename)
x = val_transforms(x_raw).unsqueeze(0)
X = torch.quantize_per_tensor(x, input_scale, input_zero_point, torch.quint8)
weight_observer = torch.quantization.observer.MinMaxObserver.with_args(
qscheme=torch.per_tensor_affine,
dtype=torch.qint8,
reduce_range=False
)()
# Inject the pre-computed weight range instead of observing data
weight_observer.max_val = weight_max
weight_observer.min_val = weight_min
qmod = torch.nn.quantized.Conv2d(
in_channels=cmod.in_channels, out_channels=cmod.out_channels, kernel_size=cmod.kernel_size,
stride=cmod.stride, padding=cmod.padding, padding_mode=cmod.padding_mode, dilation=cmod.dilation, groups=cmod.groups,
bias=False
)
qweight = _quantize_weight(cmod.weight, weight_observer)
qmod.set_weight_bias(qweight, None)
qmod.scale = output_scale
qmod.zero_point = output_zero_point
# Native quantized conv path
y_native = qmod(X).dequantize()
# Simulated path: dequantize input and weights, run fp32 conv, fake-quantize the output
y_simulated = torch.quantize_per_tensor(
torch.nn.functional.conv2d(
X.dequantize(),
qmod.weight().dequantize(),
None,
qmod.stride, qmod.padding, qmod.dilation, qmod.groups
),
qmod.scale, qmod.zero_point, torch.quint8
).dequantize()
# Bad: output channel 33 shows a large native-vs-simulated mismatch
print((y_native[0,33,18-3:18+3,23-3:23+3]-y_simulated[0,33,18-3:18+3,23-3:23+3]).abs())
# Good: output channel 32 matches closely
print((y_native[0,32,18-3:18+3,23-3:23+3]-y_simulated[0,32,18-3:18+3,23-3:23+3]).abs())
# Weight quantization error
print('Mean absolute difference', (qmod.weight().dequantize()-cmod.weight).abs().mean())
# Op error
print('Max absolute difference', (y_native-y_simulated).abs().max())
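For context on why reduce_range exists at all (this is background I believe to be accurate, not a claim about the exact root cause of the mismatch above): on x86, FBGEMM's kernels accumulate pairs of uint8 × int8 products into int16 lanes (the vpmaddubsw instruction), and with full 8-bit ranges that pairwise sum can overflow. A minimal arithmetic sketch:

```python
# Sketch: pairwise uint8 * int8 products summed into int16 can overflow
# when the full 8-bit ranges are used.
INT16_MAX = 32767

a0, a1 = 255, 255   # uint8 activations at the top of the full range
w0, w1 = 127, 127   # int8 weights at the top of the full range

pair_sum = a0 * w0 + a1 * w1
print(pair_sum)  # exceeds the int16 range

# With reduce_range=True the activation range is halved (e.g. 0..127),
# so the worst-case pairwise sum stays within int16:
pair_sum_reduced = 127 * w0 + 127 * w1
print(pair_sum_reduced)
```

This would be consistent with the full-range (reduce_range=False) case silently saturating inside the native kernel while the fp32 simulation does not, but I have not confirmed that this is what happens here.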