torch.quantization.QuantStub not replaced by convert?

I’m seeing unexpected behavior with post-training static quantization. My understanding is that, after calibration, calling torch.quantization.convert replaces the QuantStub and DeQuantStub instances with instances of torch.nn.quantized.Quantize and torch.nn.quantized.DeQuantize. What I’m seeing with torch 1.7.1 is that only the DeQuantStub is replaced.
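For reference, here is a minimal toy sketch of what I expect convert to do (toy layer and made-up names, not my real script; the qconfig is set directly on the wrapper here just to keep it short):

import torch
import torch.nn as nn

# Toy float model wrapped so its input/output are (de)quantized.
toy = torch.quantization.QuantWrapper(nn.Conv2d(3, 8, 3))
toy.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.backends.quantized.engine = 'fbgemm'

prepared = torch.quantization.prepare(toy, inplace=False)
prepared(torch.randn(1, 3, 32, 32))   # single calibration pass
converted = torch.quantization.convert(prepared, inplace=False)

print(type(converted.quant))    # expecting torch.nn.quantized.Quantize
print(type(converted.dequant))  # expecting torch.nn.quantized.DeQuantize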

The output of the code below is:

UserWarning: Please use quant_min and quant_max to specify the range for observers.
 reduce_range will be deprecated in a future release of PyTorch.
mwe.py:113: UserWarning: instance is type <class 'torch.quantization.stubs.QuantStub'>,
    not <class 'torch.nn.quantized.modules.Quantize'>

Is my understanding correct? Am I missing a step?

import time
import os
import warnings

import torch
from torchvision import models
from torchvision import transforms
from torchvision import datasets


def get_loaders(data_dir, batch_size=256, num_workers=1, pin_memory=False):
    traindir = os.path.join(data_dir, 'train')
    valdir = os.path.join(data_dir, 'val')

    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose(
            [   
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        )
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(
            valdir,
            transforms.Compose(
                [   
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                ]
            )
        ),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    return train_loader, val_loader


def calibrate(qmodel, data_loader, num_batches=1):
    for i, (input, target) in enumerate(data_loader):
        qmodel(input)
        if i == num_batches - 1:
            break


def main(model_name, num_batches=1):
    try:
        func = getattr(models, model_name)
        model = func(pretrained=True)
    except AttributeError as ae:
        msg = f'Invalid model name {model_name}? {ae}'
        raise AttributeError(msg)

    """ Wrap model so inputs and outputs are quantized. """
    quantized_model = torch.quantization.QuantWrapper(model)
    backend = "fbgemm"
    model.qconfig = torch.quantization.get_default_qconfig(backend)
    torch.backends.quantized.engine = backend

    """ Calibrate the quantizers. """
    static_quantized_model = torch.quantization.prepare(
        quantized_model, inplace=False
    )
    train_loader, val_loader = get_loaders(
        '/data/datasets/imagenet', pin_memory=False
    )
    calibrate(static_quantized_model, train_loader, num_batches=num_batches)
    static_quantized_model = torch.quantization.convert(
        static_quantized_model, inplace=False
    )

   """ After calling convert, quant and dequant should be instances of Quantize
    and DeQuantize. Strangely, quant remains an instance of torch.quantization.QuantStub."""
    def warn_type(instance, _type):
        if not isinstance(instance, _type):
            msg = f'instance is type {type(instance)}, not {_type}'
            warnings.warn(msg)
    warn_type(static_quantized_model.quant, torch.nn.quantized.Quantize)
    warn_type(static_quantized_model.dequant, torch.nn.quantized.DeQuantize)


if __name__ == '__main__':
    main('resnet50', num_batches=5)

I think the problem is that the qconfig is not being set on the model that is actually passed to prepare; it is set on the inner model instead:
model.qconfig = torch.quantization.get_default_qconfig(backend)
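You can see it by inspecting the stub after prepare (using the names from your script). QuantWrapper itself has no qconfig, so nothing propagates to its quant child, and as far as I can tell convert only swaps modules that carry a qconfig (DeQuantStub is special-cased, I believe, which is why only the dequant side gets replaced):

prepared = torch.quantization.prepare(quantized_model, inplace=False)
print(prepared.quant.qconfig)   # None -> convert() leaves this QuantStub alone
print(prepared.module.qconfig)  # the qconfig you set on the inner model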


This is the correct answer. The stubs are properly replaced when I change

model.qconfig = torch.quantization.get_default_qconfig(backend)

to

quantized_model.qconfig = torch.quantization.get_default_qconfig(backend)
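
With that one-line change both type checks pass; on my machine (torch 1.7.1) the types print as:

print(type(static_quantized_model.quant))    # <class 'torch.nn.quantized.modules.Quantize'>
print(type(static_quantized_model.dequant))  # <class 'torch.nn.quantized.modules.DeQuantize'>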

Thanks!