Quantized Conv2d op bad output

I am trying to leverage PyTorch's quantized ops, but I notice that accuracy tends to drop in some cases relative to other quantization frameworks. Inspecting further, I find that both settings of a MinMaxObserver's reduce_range cause problems. With reduce_range=False, nn.quantized.Conv2d no longer behaves as expected: in the example below, running an ImageNet sample through the first conv layer of ResNet50 produces an output that differs from quantizing-dequantizing the weights/inputs and running the fp32 conv op. With reduce_range=True, that mismatch disappears, but overall accuracy still drops and the weight quantization error goes up (see script).

I know similar issues have been brought up before, and it seems like the reduce_range parameter was introduced to resolve them. However, reducing the quantization range by one bit costs accuracy, so reduce_range is not an all-round fix.

Any clue what's causing the issue? Any fixes beyond reduce_range? Thanks.

import torch
from torch import nn
from torchvision import transforms
import torchvision.models as tvm
from torchvision.datasets.folder import default_loader
from torch.nn.quantized.modules.utils import _quantize_weight

# Quantization parameters taken from a calibration run on ResNet50's first conv layer
input_scale = 0.01865844801068306
input_zero_point = 114

weight_max = torch.Tensor([0.6586668252944946])
weight_min = torch.Tensor([-0.494000118970871])

output_scale = 0.04476073011755943
output_zero_point = 122

x_filename = 'ILSVRC2012_val_00000293.JPEG'
imsize = 224
val_transforms = transforms.Compose([transforms.ToTensor()])

model = tvm.resnet50(pretrained=True)
cmod = model.conv1

x_raw = default_loader(x_filename)
x = val_transforms(x_raw).unsqueeze(0)
X = torch.quantize_per_tensor(x, input_scale, input_zero_point, torch.quint8)

# Weight observer with hard-coded min/max so the example is deterministic
weight_observer = torch.quantization.observer.MinMaxObserver.with_args(
    qscheme=torch.per_tensor_affine,
    dtype=torch.qint8,
    reduce_range=False
)()
weight_observer.max_val = weight_max
weight_observer.min_val = weight_min

# Quantized Conv2d with the same geometry as ResNet50's conv1 (no bias)
qmod = torch.nn.quantized.Conv2d(
    in_channels=cmod.in_channels, out_channels=cmod.out_channels, kernel_size=cmod.kernel_size, 
    stride=cmod.stride, padding=cmod.padding, padding_mode=cmod.padding_mode, dilation=cmod.dilation, groups=cmod.groups,
    bias=False
)
qweight = _quantize_weight(cmod.weight, weight_observer)
qmod.set_weight_bias(qweight, None)

qmod.scale = output_scale
qmod.zero_point = output_zero_point

# Native quantized op output vs. a simulated reference (dequantize, fp32 conv, requantize)
y_native = qmod(X).dequantize()
y_simulated = torch.quantize_per_tensor(
    torch.nn.functional.conv2d(
        X.dequantize(), 
        qmod.weight().dequantize(), 
        None,
        qmod.stride, qmod.padding, qmod.dilation, qmod.groups
    ), 
    qmod.scale, qmod.zero_point, torch.quint8
).dequantize()

# Bad: channel 33 shows a large mismatch between native and simulated outputs
print((y_native[0,33,18-3:18+3,23-3:23+3]-y_simulated[0,33,18-3:18+3,23-3:23+3]).abs())

# Good: channel 32 matches
print((y_native[0,32,18-3:18+3,23-3:23+3]-y_simulated[0,32,18-3:18+3,23-3:23+3]).abs())

# Quantization error
print('Mean absolute difference', (qmod.weight().dequantize()-cmod.weight).abs().mean())

# Op error
print('Max absolute difference', (y_native-y_simulated).abs().max())  

The fbgemm kernels require reduce_range to be True, to prevent potential overflow. One could write a different kernel without this requirement, at the cost of reduced performance.

You are right that using one less bit is detrimental to accuracy. There are a couple of strategies to improve accuracy: using moving average observers, using per-channel weight observers instead of per-tensor ones, etc.
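For concreteness, here is a minimal sketch of what such a qconfig could look like; the observer choices and arguments below are my assumptions rather than an official recipe, so treat it as a starting point:

import torch
from torch.quantization import QConfig
from torch.quantization.observer import (
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
)

# Hypothetical qconfig: moving-average activation observer (reduce_range kept True
# for fbgemm) plus a per-channel symmetric weight observer.
custom_qconfig = QConfig(
    activation=MovingAverageMinMaxObserver.with_args(
        dtype=torch.quint8, reduce_range=True
    ),
    weight=MovingAveragePerChannelMinMaxObserver.with_args(
        dtype=torch.qint8, qscheme=torch.per_channel_symmetric
    ),
)

# It could then be passed as qconfig_dict = {"": custom_qconfig} to prepare_fx.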

The fbgemm kernels require reduce_range to be True, to prevent potential overflow.

Yeah, this makes sense after looking at qconv.cpp, which suggests that vpmaddubsw instructions are used in fbgemm. Strictly speaking these saturate rather than overflow, and they saturate at 16-bit intermediates rather than accumulating into 32 bits. So is reduce_range meant to prevent this?
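To make the 16-bit saturation concrete, here is a quick back-of-the-envelope check (my own illustration of the vpmaddubsw pattern, not actual fbgemm code), where two adjacent u8*s8 products are summed into a signed 16-bit intermediate:

INT16_MAX = 32767

a_max = 255      # full quint8 activation range
w_full = 128     # max |w| with the full qint8 range (reduce_range=False)
w_reduced = 64   # max |w| with the range reduced by one bit (reduce_range=True)

# Two adjacent u8*s8 products summed into an int16 intermediate
print(2 * a_max * w_full, '>', INT16_MAX)      # 65280 > 32767 -> saturates
print(2 * a_max * w_reduced, '<=', INT16_MAX)  # 32640 <= 32767 -> fits

So at least arithmetically, full-range int8 weights can push the pairwise sum past int16, while dropping one bit keeps it in range, which would match what I see with reduce_range=False vs True.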

I also modified the script a little and found that in torch fx, the fbgemm option doesn't quite reduce the range by one bit, so overflow/saturation is still observed:

import torch
from torch import nn
from torchvision import transforms
import torchvision.models as tvm
from torchvision.datasets.folder import default_loader
from torch.nn.quantized.modules.utils import _quantize_weight

# Monkey-patch FX fusion to a no-op so conv1 stays a standalone quantized Conv2d
def _fuse_fx(graph_module, fuse_custom_config_dict):
    return graph_module

import torch.quantization.quantize_fx
torch.quantization.quantize_fx._fuse_fx = _fuse_fx

backend_str = 'fbgemm'

x_filename = 'ILSVRC2012_val_00000293.JPEG'
imsize = 224
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)
val_transforms = transforms.Compose([
    transforms.Resize(imsize + 24),
    transforms.CenterCrop(imsize),
    transforms.ToTensor(),
    normalize
])

x_raw = default_loader(x_filename)
x_encoding = val_transforms(x_raw).unsqueeze(0)
torch.random.manual_seed(0)
x_test = x_encoding*2.0

model = tvm.resnet50(pretrained=True)
weight = model.conv1.weight.detach().clone()
import torch.quantization.quantize_fx as quantize_fx
qconfig_dict = {"": torch.quantization.get_default_qconfig(backend_str)}
model.eval()
model = quantize_fx.prepare_fx(model, qconfig_dict)
model(x_encoding)
model = quantize_fx.convert_fx(model)

qmod = model.conv1

X_test = torch.quantize_per_tensor(x_test, model.conv1_input_scale_0, model.conv1_input_zero_point_0, model.conv1_input_dtype_0)

y_native = qmod(X_test).dequantize()
y_simulated = torch.quantize_per_tensor(
    torch.nn.functional.conv2d(
        X_test.dequantize(),
        qmod.weight().dequantize(),
        None,
        qmod.stride, qmod.padding, qmod.dilation, qmod.groups
    ),
    qmod.scale, qmod.zero_point, torch.quint8
).dequantize()

if backend_str == 'fbgemm':
    # Per-channel: ratio of the actual weight scale to the naive (max - min)/255 scale
    w_max = weight.amax(dim=(1, 2, 3))
    w_min = weight.amin(dim=(1, 2, 3))
    print(qmod.weight().q_per_channel_scales()/((w_max - w_min)/255))
else:
    # Per-tensor: same ratio for a single scale
    print(qmod.weight().q_scale()/((weight.max()-weight.min())/255))

# Bad: channel 33 shows a large mismatch between native and simulated outputs
print((y_native[0,33,18-3:18+3,23-3:23+3]-y_simulated[0,33,18-3:18+3,23-3:23+3]).abs())

# Good: channel 32 matches
print((y_native[0,32,18-3:18+3,23-3:23+3]-y_simulated[0,32,18-3:18+3,23-3:23+3]).abs())

# Op error
print('Max absolute difference', (y_native-y_simulated).abs().max())

# Number of output elements that disagree beyond a small tolerance
print(len(torch.where((y_native-y_simulated).abs() > 0.0005)[0]))

My output shows that every channel's scale is greater than 1x but less than 2x the expected scale, which suggests the range is reduced by less than one bit, and there are a handful of overflow/saturation instances. On the other hand, if I manually set an observer with reduce_range=True, the scale is exactly 2x the expected scale (reduced by exactly one bit) and there are no overflow/saturation issues. So my question is: is the default behavior in torch fx not to reduce the range?
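(For reference, the manual-observer check I mean is just something like this standalone sketch, reusing the weight min/max from the first script:)

import torch
from torch.quantization.observer import MinMaxObserver

for reduce_range in (False, True):
    obs = MinMaxObserver(dtype=torch.qint8,
                         qscheme=torch.per_tensor_affine,
                         reduce_range=reduce_range)
    # Feed the recorded weight min/max through the observer
    obs(torch.tensor([-0.494000118970871, 0.6586668252944946]))
    scale, zero_point = obs.calculate_qparams()
    print(reduce_range, scale.item())

# With reduce_range=True the qint8 range shrinks from 256 to 128 levels,
# so the scale is about 2x the reduce_range=False scale.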

Looks like it depends on the qconfig, i.e. the backend you are working on:

In this case it's not a MinMax observer, but reduce_range is set to true.
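A quick way to see that for both backends is to instantiate the default qconfig's observers and print them (this only uses get_default_qconfig; the exact output depends on your torch version):

import torch
from torch.quantization import get_default_qconfig

for backend in ('fbgemm', 'qnnpack'):
    qconfig = get_default_qconfig(backend)
    # Calling the partials builds the observer instances, which expose reduce_range
    act_obs = qconfig.activation()
    w_obs = qconfig.weight()
    print(backend,
          type(act_obs).__name__, act_obs.reduce_range,
          type(w_obs).__name__, w_obs.reduce_range)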

Makes sense, and thanks for pointing this out; it's clear that neither backend relies on MinMax observers. Yet in both cases one still gets overflow/saturation issues. I wonder if this is due to the use of Histogram observers. To be specific, in either case the resulting range is less than 2x the min-max range, so perhaps the issue is only resolved when the range is at least 2x the min-max range?

Yeah, I tried to dig into the observer code to see if I could spot anything. I'm not seeing any usage of reduce_range in the HistogramObserver itself, but the underlying observer base class has a bit here:

None of it looks out of the ordinary, although I'm not super familiar with the observer part of the codebase. Having said that, there is:

which indicates it may be on the way out anyway.
