How does PyTorch implement Quantization?

Hello, I want to know how PyTorch implements quantization without rewriting the layer definitions.
For certain reasons, I have to write a quantization framework on top of PyTorch for my project. Here is a simple example of my implementation:

from functools import wraps

def quantization_decorator(func):
    @wraps(func)
    def wrapper(self, x):
        # do_quantization is my own routine; it returns the (fake-)quantized tensor
        x = do_quantization(x)
        x = func(self, x)
        x = do_quantization(x)
        return x
    return wrapper

class Conv2d(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @quantization_decorator
    def forward(self, x):
        return super().forward(x)

Since my quantization method is defined as a decorator:
ADVANTAGE: I can apply it to forward to choose whether to quantize a given module.
DISADVANTAGE: I have to rewrite each layer, e.g. class Conv2d in my example (one way around this is sketched below).
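
One way to avoid rewriting every layer by hand is to swap modules programmatically after the model is built, which is roughly what PyTorch's eager-mode convert step does with its own float-to-quantized module mapping. A minimal sketch, assuming the Conv2d subclass above (the helper name swap_conv2d is mine, not a PyTorch API):

import torch.nn as nn

def swap_conv2d(model: nn.Module, quantized_cls=Conv2d):
    # Recursively replace every plain nn.Conv2d with the decorated subclass,
    # copying the original parameters so only the quantization hooks change.
    for name, child in model.named_children():
        if isinstance(child, nn.Conv2d) and not isinstance(child, quantized_cls):
            new_conv = quantized_cls(
                child.in_channels, child.out_channels, child.kernel_size,
                stride=child.stride, padding=child.padding,
                dilation=child.dilation, groups=child.groups,
                bias=child.bias is not None, padding_mode=child.padding_mode,
            )
            new_conv.load_state_dict(child.state_dict())
            setattr(model, name, new_conv)
        else:
            swap_conv2d(child, quantized_cls)
    return model

With this, the decorator stays, but Conv2d is defined once and swap_conv2d(model) is called instead of editing each model definition.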

I want to know how PyTorch implements quantization, so that I can improve my framework (make it simpler, faster, and more customizable).

Thank you.

  • My quantization method is different from PyTorch's.
  • My custom layers have to be supported.

That’s why I have to write a new implementation. :see_no_evil:

Hi @Eta_C,

Please look at the flow of operations for quantization here: https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html#post-training-static-quantization

The main steps for post-training quantization are:
1) Fusing modules (e.g., conv + bn + relu => conv_bn_relu).
2) Observing tensor values (calibration) to compute the quantization parameters.
3) Replacing the float modules with their quantized counterparts.
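
For example, in eager mode those three steps map onto the torch.quantization API roughly as follows. This is a self-contained sketch with a toy model, not the tutorial's exact code:

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv = nn.Conv2d(3, 8, 3)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.bn(self.conv(x)))
        return self.dequant(x)

    def fuse_model(self):
        # step 1: fuse conv + bn + relu into a single module
        torch.quantization.fuse_modules(self, ['conv', 'bn', 'relu'], inplace=True)

model = TinyNet().eval()
model.fuse_model()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)      # step 2: insert observers
with torch.no_grad():
    model(torch.randn(4, 3, 32, 32))                 # calibration pass
torch.quantization.convert(model, inplace=True)      # step 3: swap float modules for quantized ones
print(model)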

I know how to use the PyTorch quantization tools.
I want to know how PyTorch implements them.
OK, maybe I have to read some of the source code…

Yes, reading the source code to see how it's implemented is a good way. Here are some pointers:
Python side: https://github.com/pytorch/pytorch/tree/master/torch/quantization
C++ kernels: https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native/quantized/cpu

You may also want to check the recent changes to these files and the comments on related PRs to get some context.

Thank you.
I will take the time to read the source code carefully. :grinning:

I was trying to quantize a VGG model, following the steps given in https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html

I did the first step as follows:

import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub
#from .utils import load_state_dict_from_url
from torch.hub import load_state_dict_from_url

__all__ = [
    'Cifar_VGG', 'cifar_rvgg11', 'cifar_rvgg11_bn', 'cifar_vgg11', 'cifar_vgg11_bn',
    'cifar_vgg13', 'cifar_vgg13_bn', 'cifar_vgg16', 'cifar_vgg16_bn',
    'cifar_vgg19_bn', 'cifar_vgg19',
]

model_urls = {
    'cifar_vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
}

class Cifar_VGG(nn.Module):

    def __init__(self, features, num_classes=1000, init_weights=True):
        super(Cifar_VGG, self).__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.classifier = nn.Sequential(
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(True),
            nn.Linear(512, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.quant(x)
        x = self.features(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        x = self.dequant(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def fuse_model(self):
        torch.quantization.fuse_modules(self, [['conv2d', 'BatchNorm2d', 'ReLU'],
                                               ], inplace=True)

def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)

cfgs = {
    'RA': [512, 'M', 512, 'M', 512, 512, 'M', 512, 512, 'M', 512, 512, 'M'],
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
    if pretrained:
        kwargs['init_weights'] = False
    model = Cifar_VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def cifar_rvgg11_bn(pretrained=False, progress=True, **kwargs):
    return _vgg('cifar_rvgg11', 'RA', True, pretrained, progress, **kwargs)

net.fuse_model()
This gives the following error:

'Cifar_VGG' object has no attribute 'conv2d'

I think the fusion is not defined correctly for your model. fuse_modules expects the names of submodules as they appear in the model (e.g. 'features.0'), not class names like 'conv2d'. Please read the fusion section of the tutorial again and see if you can find the problem.
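
For example, for the batch-norm variant built by make_layers above, the Conv2d/BatchNorm2d/ReLU triples sit at consecutive indices inside self.features, so a fuse_model along these lines is one possibility (a sketch, not code from the tutorial; the indices are discovered at runtime rather than hard-coded):

def fuse_model(self):
    # Fuse every Conv2d -> BatchNorm2d -> ReLU triple inside self.features,
    # addressing the submodules by their names ('0', '1', '2', ...) as
    # fuse_modules expects.
    children = list(self.features)
    for i in range(len(children) - 2):
        if (isinstance(children[i], nn.Conv2d)
                and isinstance(children[i + 1], nn.BatchNorm2d)
                and isinstance(children[i + 2], nn.ReLU)):
            torch.quantization.fuse_modules(
                self.features, [str(i), str(i + 1), str(i + 2)], inplace=True)

Note that the BatchNorm1d in self.classifier also runs between the quant and dequant stubs, so it may need similar treatment (or its own dequant/quant pair around it).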


Now I am getting the following error:
RuntimeError: No function is registered for schema aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) on tensor type QuantizedCPUTensorId; available functions are CPUTensorId, CUDATensorId, MkldnnCPUTensorId, VariableTensorId

How do I merge a batch_norm layer with a conv layer?

Now I am trying ResNet18:

def fuse_model(self):
    for m in self.modules():
        if type(m) == BasicBlock:
            torch.quantization.fuse_modules(m, [['conv1', 'bn1', 'relu'], ['conv2', 'bn2']], inplace=True)
            conv = getattr(m, 'downsample')
            if conv:
                torch.quantization.fuse_modules(conv, ['0', '1'], inplace=True)

Fusion itself does not give any error, but inference still shows the error I posted above.

Got it!!!
I missed the following:
torch.quantization.fuse_modules(self, ['conv1', 'bn1', 'relu'], inplace=True)

Thank you
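
For reference, the complete fuse_model for a torchvision-style ResNet18 then looks roughly like this (a sketch; BasicBlock and the names conv1/bn1/relu/downsample are assumed to match torchvision's ResNet definition):

def fuse_model(self):
    # Fuse the stem first: conv1 -> bn1 -> relu on the ResNet itself
    # (this is the call that was missing above).
    torch.quantization.fuse_modules(self, ['conv1', 'bn1', 'relu'], inplace=True)
    # Then fuse inside every BasicBlock, including the 1x1 downsample branch.
    for m in self.modules():
        if type(m) == BasicBlock:
            torch.quantization.fuse_modules(
                m, [['conv1', 'bn1', 'relu'], ['conv2', 'bn2']], inplace=True)
            if m.downsample is not None:
                torch.quantization.fuse_modules(m.downsample, ['0', '1'], inplace=True)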
