The quant and dequant stubs define where the dtype changes from fp32 to int8, so you can have sections that get quantized and sections that don't. However, you also need to set the qconfig appropriately: any module with a qconfig set will be quantized (along with its children).
I.e., in my example I have the quant and dequant stubs surround the entire model, since I apply the qconfig to the top level of the model (during the quantization flow that same qconfig gets propagated to the rest of the model).
If you have parts of the model that you don't want quantized, you need to make sure those parts don't have a qconfig and that they are surrounded by a dequant and a quant.
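For instance, here is a minimal sketch of that pattern (part_a / part_b are made-up names, just for illustration):

model.qconfig = my_qconfig    # everything under model gets quantized...
model.part_b.qconfig = None   # ...except part_b, which is left in fp32
# then in forward(), bracket the fp32 section with the stubs:
#   x = self.dequant(x)   # int8 -> fp32 before the unquantized part
#   x = self.part_b(x)
#   x = self.quant(x)     # fp32 -> int8 afterwards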
If you can make a toy repro that shows your issue, that would be helpful; Vasiliy was saying that it's unclear which modules in your example have a qconfig set and which do not.
As a concrete example, here is how I would handle the BasicBlocks so that the shortcut is never quantized:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization import FakeQuantize, HistogramObserver, MinMaxObserver, QConfig
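# LambdaLayer isn't defined in the snippet; I'm assuming the standard helper
# from the CIFAR ResNet reference code here so the example is self-contained:
class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)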
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option="A"):
        super(BasicBlock, self).__init__()
        # self.quant = torch.ao.quantization.QuantStub()  # I would put these in the top level
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.q = torch.ao.quantization.QuantStub()
        self.dq = torch.ao.quantization.DeQuantStub()
        # eager mode quantization works poorly with functionals; we need modules in order to do fusion
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == "A":
                self.shortcut = LambdaLayer(
                    lambda x: F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes // 4, planes // 4), "constant", 0)
                )
            elif option == "B":
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(self.expansion * planes),
                )
        # self.dequant = torch.ao.quantization.DeQuantStub()  # I would put these in the top level

    def forward(self, x):
        # out = self.quant(x)  # I would put these in the top level
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.dq(out)
        out += self.shortcut(x)  # the shortcut and the add run in fp32, surrounded by dq and q
        out = self.q(out)
        out = self.relu2(out)
        # out = self.dequant(out)  # I would put these in the top level
        return out
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.bb1 = BasicBlock(1, 1, 1)
        self.bb2 = BasicBlock(1, 1, 1)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.bb1(x)
        x = self.bb2(x)
        x = self.dequant(x)
        return x
model_fp32 = Net().eval()
# B is the bit width
B = 4
# intB qconfig:
intB_act_fq = FakeQuantize.with_args(
    observer=HistogramObserver,
    quant_min=0,
    quant_max=2 ** B - 1,  # 15 for B=4: unsigned affine activations
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    reduce_range=False,
)
intB_weight_fq = FakeQuantize.with_args(
    observer=MinMaxObserver,
    quant_min=-(2 ** B) // 2,  # -8 for B=4: signed symmetric weights
    quant_max=(2 ** B) // 2 - 1,  # 7 for B=4
    dtype=torch.qint8,
    qscheme=torch.per_tensor_symmetric,
    reduce_range=False,
)
intB_qconfig = QConfig(activation=intB_act_fq, weight=intB_weight_fq)
# leave the shortcut out: I doubt a custom LambdaLayer can be quantized in eager mode anyway
for subnet in [model_fp32.bb1, model_fp32.bb2]:
    for name, module in subnet.named_children():
        if "shortcut" not in name:
            module.qconfig = intB_qconfig
# the top-level stubs need a qconfig too so they get observers
model_fp32.quant.qconfig = intB_qconfig
model_fp32.dequant.qconfig = intB_qconfig
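# optional sanity check: print exactly which modules ended up with a qconfig,
# since that's what was unclear in the original example
for name, module in model_fp32.named_modules():
    if getattr(module, "qconfig", None) is not None:
        print("qconfig set on:", name)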
to_fuse = [
    ["bb1.conv1", "bb1.bn1", "bb1.relu1"],
    ["bb1.conv2", "bb1.bn2"],
    ["bb2.conv1", "bb2.bn1", "bb2.relu1"],
    ["bb2.conv2", "bb2.bn2"],
]
model_fp32_fused = torch.ao.quantization.fuse_modules_qat(model_fp32, to_fuse).train()
model_fp32_prepared = torch.ao.quantization.prepare_qat(model_fp32_fused)
# calibrate the model: the observers collect statistics while fake quant is disabled
model_fp32_prepared.apply(torch.ao.quantization.disable_fake_quant)
# calibration_code()
model_fp32_prepared(torch.randn(1, 1, 10, 10))
# re-enable fake quant and freeze the observers so the qparams don't change while running test data
model_fp32_prepared.apply(torch.ao.quantization.enable_fake_quant)
model_fp32_prepared.apply(torch.ao.quantization.disable_observer)
# test intB numbers
# test_code()
model_fp32_prepared(torch.randn(1, 1, 10, 10))
print(model_fp32_prepared)
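If you eventually want an actual quantized model rather than the fake-quant simulation, the usual last step of the eager-mode flow is torch.ao.quantization.convert (not shown in the snippet above, so take this as a sketch):

model_int8 = torch.ao.quantization.convert(model_fp32_prepared.eval())

Though for sub-8-bit experiments like this, it's common to stop at the QAT simulation since there are no native intB kernels.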