I’m seeing unexpected behavior with post-training static quantization. My understanding is that after calibration, the instances of QuantStub
and DeQuantStub
are replaced by instances of torch.nn.quantized.Quantize
and torch.nn.quantized.DeQuantize
. What I’m seeing with torch 1.7.1 is that only DeQuantStub
is being replaced.
The output of the code below is:
UserWarning: Please use quant_min and quant_max to specify the range for observers.
reduce_range will be deprecated in a future release of PyTorch.
mwe.py:113: UserWarning: instance is type <class 'torch.quantization.stubs.QuantStub'>,
not <class 'torch.nn.quantized.modules.Quantize'>
Is my understanding correct? Am I missing a step?
import time
import os
import warnings
import torch
from torchvision import models
from torchvision import transforms
from torchvision import datasets
def get_loaders(data_dir, batch_size=256, num_workers=1, pin_memory=False):
traindir = os.path.join(data_dir, 'train')
valdir = os.path.join(data_dir, 'val')
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
train_dataset = datasets.ImageFolder(
traindir,
transforms.Compose(
[
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]
)
)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=pin_memory
)
val_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(
valdir,
transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
]
)
),
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=pin_memory
)
return train_loader, val_loader
def calibrate(qmodel, data_loader, num_batches=1):
for i, (input, target) in enumerate(data_loader):
qmodel(input)
if i == num_batches - 1:
break
def main(model_name, num_batches=1):
try:
func = getattr(models, model_name)
model = func(pretrained=True)
except AttributeError as ae:
msg = f'Invalid model name {model_name}? {ae}'
raise AttributeError(msg)
""" Wrap model so inputs and outputs are quantized. """
quantized_model = torch.quantization.QuantWrapper(model)
backend = "fbgemm"
model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend
""" Calibrate the quantizers. """
static_quantized_model = torch.quantization.prepare(
quantized_model, inplace=False
)
train_loader, val_loader = get_loaders(
'/data/datasets/imagenet', pin_memory=False
)
calibrate(static_quantized_model, train_loader, num_batches=num_batches)
static_quantized_model = torch.quantization.convert(
static_quantized_model, inplace=False
)
""" After calling convert, quant and dequant should be instances of Quantize
and DeQuantize. Strangely, quant remains an instance of torch.quantization.QuantStub."""
def warn_type(instance, _type):
if not isinstance(instance, _type):
msg = f'instance is type {type(instance)}, not {_type}'
warnings.warn(msg)
warn_type(static_quantized_model.quant, torch.nn.quantized.Quantize)
warn_type(static_quantized_model.dequant, torch.nn.quantized.DeQuantize)
if __name__ == '__main__':
main('resnet50', num_batches=5)