Quantization/QAT causing jit.script to fail

Hello

I’m trying to do QAT -> TorchScript but am getting an error.

My model is:

from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import QuantStub, DeQuantStub

def SeperableConv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, activation_fn=nn.ReLU):
    """Replace Conv2d with a depthwise Conv2d and Pointwise Conv2d.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size,
                  groups=in_channels, stride=stride, padding=padding),
        activation_fn(),
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
    )


class SSD(nn.Module):
    def __init__(self, num_classes: int, is_test=False, config=None, device=None, activation_fn=nn.ReLU):
        """Compose a SSD model using the given components.
        """
        super(SSD, self).__init__()
        self.base_channel = 16
        self.num_classes = num_classes
        self.is_test = is_test

        if device:
            self.device = device
        else:
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        if config:
            self.center_variance = torch.tensor([config['center_variance']], device=device)
            self.size_variance = torch.tensor([config['size_variance']], device=device)
            self.priors = config['priors'].to(self.device)
        else:
            self.center_variance = torch.tensor([0.1], device=device)
            self.size_variance = torch.tensor([0.2], device=device)
        
        self.extras = nn.Sequential(
            nn.Conv2d(in_channels=self.base_channel * 16, out_channels=self.base_channel * 4, kernel_size=1),
            activation_fn(),
            SeperableConv2d(in_channels=self.base_channel * 4, out_channels=self.base_channel *
                               16, kernel_size=3, stride=2, padding=1, activation_fn=activation_fn),
            activation_fn()
        )

        self.regression_headers0 = SeperableConv2d(in_channels=self.base_channel * 4, out_channels=3 *
                                                   4, kernel_size=3, padding=1, activation_fn=activation_fn)
        self.regression_headers1 = SeperableConv2d(in_channels=self.base_channel * 8, out_channels=2 *
                                                   4, kernel_size=3, padding=1, activation_fn=activation_fn)
        self.regression_headers2 = SeperableConv2d(in_channels=self.base_channel * 16, out_channels=2 *
                                                   4, kernel_size=3, padding=1, activation_fn=activation_fn)
        self.regression_headers3 = nn.Conv2d(in_channels=self.base_channel * 16,
                                             out_channels=3 * 4, kernel_size=3, padding=1)

        self.classification_headers0 = SeperableConv2d(in_channels=self.base_channel * 4, out_channels=3 *
                                                       num_classes, kernel_size=3, padding=1, activation_fn=activation_fn)
        self.classification_headers1 = SeperableConv2d(in_channels=self.base_channel * 8, out_channels=2 *
                                                       num_classes, kernel_size=3, padding=1, activation_fn=activation_fn)
        self.classification_headers2 = SeperableConv2d(in_channels=self.base_channel * 16, out_channels=2 *
                                                       num_classes, kernel_size=3, padding=1, activation_fn=activation_fn)
        self.classification_headers3 = nn.Conv2d(in_channels=self.base_channel *
                                              16, out_channels=3 * num_classes, kernel_size=3, padding=1)

        def conv_bn(inp, oup, stride):
            return nn.Sequential(
                nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
                nn.BatchNorm2d(oup),
                activation_fn()
            )

        def conv_dw(inp, oup, stride):
            return nn.Sequential(
                nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
                nn.BatchNorm2d(inp),
                activation_fn(),

                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
                activation_fn(),
            )
        self.backbone_chunk1 = nn.Sequential(
            conv_bn(3, self.base_channel, 2),  # 160*120
            conv_dw(self.base_channel, self.base_channel * 2, 1),
            conv_dw(self.base_channel * 2, self.base_channel * 2, 2),  # 80*60
            conv_dw(self.base_channel * 2, self.base_channel * 2, 1),
            conv_dw(self.base_channel * 2, self.base_channel * 4, 2),  # 40*30
            conv_dw(self.base_channel * 4, self.base_channel * 4, 1),
            conv_dw(self.base_channel * 4, self.base_channel * 4, 1),
            # BasicRFB(self.base_channel * 4, self.base_channel * 4, stride=1, scale=1.0, activation_fn=activation_fn)
        )
        self.backbone_chunk2 = nn.Sequential(
            conv_dw(self.base_channel * 4, self.base_channel * 8, 2),  # 20*15
            conv_dw(self.base_channel * 8, self.base_channel * 8, 1),
            conv_dw(self.base_channel * 8, self.base_channel * 8, 1),
        )
        self.backbone_chunk3 = nn.Sequential(
            conv_dw(self.base_channel * 8, self.base_channel * 16, 2),  # 10*8
            conv_dw(self.base_channel * 16, self.base_channel * 16, 1)
        )

        self.quant0 = QuantStub()
        self.quant1 = QuantStub()
        self.quant2 = QuantStub()
        self.quant3 = QuantStub()
        self.dequant = DeQuantStub()


    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        confidences = []
        locations = []

        x = self.quant0(x)
        for layer in self.backbone_chunk1:
            x = layer(x)
        x = self.dequant(x)
        confidence = self.classification_headers0(x)
        confidence = confidence.permute(0, 2, 3, 1).contiguous()
        confidence = confidence.view(confidence.size(0), -1, self.num_classes)
        location = self.regression_headers0(x)
        location = location.permute(0, 2, 3, 1).contiguous()
        location = location.view(location.size(0), -1, 4)
        confidences.append(confidence)
        locations.append(location)

        # x = self.quant1(x)
        for layer in self.backbone_chunk2:
            x = layer(x)
        # x = self.dequant(x)
        confidence = self.classification_headers1(x)
        confidence = confidence.permute(0, 2, 3, 1).contiguous()
        confidence = confidence.view(confidence.size(0), -1, self.num_classes)
        location = self.regression_headers1(x)
        location = location.permute(0, 2, 3, 1).contiguous()
        location = location.view(location.size(0), -1, 4)
        confidences.append(confidence)
        locations.append(location)

        # x = self.quant2(x)
        for layer in self.backbone_chunk3:
            x = layer(x)
        # x = self.dequant(x)
        confidence = self.classification_headers2(x)
        confidence = confidence.permute(0, 2, 3, 1).contiguous()
        confidence = confidence.view(confidence.size(0), -1, self.num_classes)
        location = self.regression_headers2(x)
        location = location.permute(0, 2, 3, 1).contiguous()
        location = location.view(location.size(0), -1, 4)
        confidences.append(confidence)
        locations.append(location)

        # x = self.quant3(x)
        x = self.extras(x)
        # x = self.dequant(x)
        confidence = self.classification_headers3(x)
        confidence = confidence.permute(0, 2, 3, 1).contiguous()
        confidence = confidence.view(confidence.size(0), -1, self.num_classes)
        location = self.regression_headers3(x)
        location = location.permute(0, 2, 3, 1).contiguous()
        location = location.view(location.size(0), -1, 4)
        confidences.append(confidence)
        locations.append(location)

        confidences = torch.cat(confidences, 1)
        locations = torch.cat(locations, 1)

        return confidences, locations

    def load(self, model):
        self.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage))

    def save(self, model_path):
        torch.save(self.state_dict(), model_path)

I do QAT -> TorchScript and test it by running this code:

from ssd import SSD
...
net = SSD(num_classes=2, device=device, config=config)
...
net.load(trained_model_path)
net.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(net, inplace=True)

for epoch in range(num_epochs):
    train()  # etc...

    net.eval()
    net.cpu()
    # convert to quantised
    quant_net = deepcopy(net)
    quant_net = torch.quantization.convert(quant_net, inplace=False)
    quant_net.save(os.path.join(args.checkpoint_folder, f"quantised-net.pth"))

    m = torch.jit.script(quant_net)
    m.cpu()
    dummy = torch.randn(1, 3, 480, 640).cpu().float()
    a = m.forward(dummy) # test to see if scripted module works

    torch.jit.save(m, os.path.join(args.checkpoint_folder, f"jit-net.pt"))

    net.to(DEVICE)

and I get the following error on the line a = m.forward(dummy):

Traceback (most recent call last):
  File "train_testt.py", line 360, in <module>
    a = m.forward(dummy)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/home/joel/Desktop/Ultra-Light-Fast-Generic-Face-Detector-1MB/minimod.py", line 124, in forward
            x = layer(x)
        x = self.dequant(x)
        confidence = self.classification_headers0(x)
                     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        confidence = confidence.permute(0, 2, 3, 1).contiguous()
        confidence = confidence.view(confidence.size(0), -1, self.num_classes)
  File "/home/joel/anaconda3/envs/nightlytorch/lib/python3.8/site-packages/torch/nn/modules/container.py", line 117, in forward
    def forward(self, input):
        for module in self:
            input = module(input)
                    ~~~~~~ <--- HERE
        return input
  File "/home/joel/anaconda3/envs/nightlytorch/lib/python3.8/site-packages/torch/nn/modules/container.py", line 117, in forward
    def forward(self, input):
        for module in self:
            input = module(input)
                    ~~~~~~ <--- HERE
        return input
  File "/home/joel/anaconda3/envs/nightlytorch/lib/python3.8/site-packages/torch/nn/quantized/modules/conv.py", line 326, in forward
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        return ops.quantized.conv2d(
               ~~~~~~~~~~~~~~~~~~~~ <--- HERE
            input, self._packed_params, self.scale, self.zero_point)
RuntimeError: Could not run 'quantized::conv2d.new' with arguments from the 'CPU' backend. 'quantized::conv2d.new' is only available for these backends: [QuantizedCPU, BackendSelect, Named, Autograd, Profiler, Tracer, Autocast, Batched].

QuantizedCPU: registered at /opt/conda/conda-bld/pytorch_1594145889316/work/aten/src/ATen/native/quantized/cpu/qconv.cpp:736 [kernel]
BackendSelect: fallthrough registered at /opt/conda/conda-bld/pytorch_1594145889316/work/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Named: registered at /opt/conda/conda-bld/pytorch_1594145889316/work/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Autograd: fallthrough registered at /opt/conda/conda-bld/pytorch_1594145889316/work/aten/src/ATen/core/VariableFallbackKernel.cpp:31 [backend fallback]
Profiler: registered at /opt/conda/conda-bld/pytorch_1594145889316/work/torch/csrc/autograd/profiler.cpp:677 [backend fallback]
Tracer: fallthrough registered at /opt/conda/conda-bld/pytorch_1594145889316/work/torch/csrc/jit/frontend/tracer.cpp:960 [backend fallback]
Autocast: fallthrough registered at /opt/conda/conda-bld/pytorch_1594145889316/work/aten/src/ATen/autocast_mode.cpp:375 [backend fallback]
Batched: registered at /opt/conda/conda-bld/pytorch_1594145889316/work/aten/src/ATen/BatchingRegistrations.cpp:149 [backend fallback]

This error does not occur if I remove all QAT/quantization lines and just jit.script the original model. The same error still occurs if I remove all Quant/DeQuantStubs in the model.
Does anyone know why this error occurs?
Could I also ask whether the commented-out QuantStubs/DeQuantStubs in the model are correctly placed?

Thank you!

This means the QuantStub/DeQuantStub placement in the model is not correct: the input to quantized::conv2d is not quantized yet. You can look at the model and check whether a QuantStub is missing before a conv2d module.
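For example, a quick way to inspect this is to print the converted module tree and check that every quantized conv is reached through a Quantize/QuantStub (a minimal sketch; quant_net stands for the converted model from your post above):

for name, module in quant_net.named_modules():
    # a quantized conv whose input path has no Quantize/QuantStub upstream
    # is the one that fails with "Could not run 'quantized::conv2d' ... 'CPU' backend"
    print(name, '->', type(module).__name__)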

Looking at the code, most likely it’s here:

x = self.quant0(x)
for layer in self.backbone_chunk1:
    x = layer(x)

Only the first x is quantized in this case; instead, you should quantize x for each activation in the loop:

x = self.quant0(x)
for i, layer in enumerate(self.backbone_chunk1):
    x = layer(x)
    x = self.quants[i](x)

and define a list of QuantStub instances with the same length as self.backbone_chunk1 in __init__, for example as sketched below.
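A minimal sketch of that list (it assumes the attribute name self.quants used above; an nn.ModuleList is needed so the stubs are registered as submodules and picked up by prepare_qat and convert):

# in __init__, one QuantStub per layer of backbone_chunk1
self.quants = nn.ModuleList(
    [QuantStub() for _ in range(len(self.backbone_chunk1))]
)

Each stub then gets its own observer during QAT, so every intermediate activation in the loop ends up with its own scale and zero_point after convert.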


Hello. Thanks for the reply!

I tried what you suggested, but I still get a slightly different error:

RuntimeError: Could not run 'quantized::conv2d' with arguments from the 'CPUTensorId' backend. 'quantized::conv2d' is only available for these backends: [QuantizedCPUTensorId].

I’m wondering how to solve the problem on a simpler model which has the same error/issue:

import torch
from torch import nn, optim
from torch.quantization import QuantStub, DeQuantStub
from copy import deepcopy

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.backbone0 = nn.Sequential(
            nn.Conv2d(3, 1, 1, bias=False),
            nn.BatchNorm2d(1),
            nn.ReLU(),
        )
        self.backbone1 = nn.Sequential(
            nn.Conv2d(1, 2, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(2),
            nn.AvgPool2d(14),
            nn.Sigmoid(),
        )
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
    def forward(self, x):
        x = self.quant(x)
        x = self.backbone0(x)
        x = self.dequant(x)
        x = self.backbone1(x)
       
        return x

model = Model()

model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)

optimizer = optim.Adam(model.parameters(), lr=1)

model.to(device)
print(model)

criterion = nn.BCELoss()

for epoch in range(10):
    model.train()

    inputs = torch.rand(2, 3, 28, 28)
    labels = torch.FloatTensor([[1, 1], [0, 0]])

    inputs = inputs.to(device)
    labels = labels.to(device)
    outputs = model(inputs)
    loss = criterion(outputs.view(2, 2), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >= 2:
        model.apply(torch.quantization.disable_observer)

    if epoch >= 3:
        model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

    quant_model = deepcopy(model)
    quant_model = torch.quantization.convert(quant_model.eval().cpu(), inplace=False)

    with torch.no_grad():
        out = quant_model(torch.rand(1, 3, 28, 28))

I tried to prepare_qat only on backbone0 as well, but got the same error:

model.backbone0.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model.backbone0, inplace=True)

What is the correct way to place QuantStubs in the forward() such that only backbone0 is quantised?

I managed to get it to look correct (judging from print(quant_model)) and not error out by preparing QAT only on backbone0 and inserting the QuantStub/DeQuantStub into the nn.Sequential itself:

...
    self.backbone0 = nn.Sequential(
            QuantStub(),
            nn.Conv2d(3, 1, 1, bias=False),
            nn.BatchNorm2d(1),
            nn.ReLU(),
            DeQuantStub()
        )
...
model.backbone0.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model.backbone0, inplace=True)

I notice that when I forward zeros through it, there is a random chance of getting either the same output as the non-converted model or drastically different outputs.
Is this the wrong way to get backbone0 quantised?

I notice that when I forward zeros through it, there is a random chance of getting either the same output as the non-converted model or drastically different outputs.

Is this after QAT?

Yes

The code now:
import torch
from torch import nn, optim
from torch.quantization import QuantStub, DeQuantStub
from copy import deepcopy

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.backbone0 = nn.Sequential(
            QuantStub(),
            nn.Conv2d(3, 1, 1, bias=False),
            nn.BatchNorm2d(1),
            nn.ReLU(),
            DeQuantStub(),
        )
        self.backbone1 = nn.Sequential(
            nn.Conv2d(1, 2, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(2),
            nn.AvgPool2d(14),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.backbone0(x)
        x = self.backbone1(x)
        return x

model = Model()

model.backbone0.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model.backbone0, inplace=True)

optimizer = optim.Adam(model.parameters(), lr=1)
model.to(device)
print(model)



criterion = nn.BCELoss()

for epoch in range(10):
    print('EPOCH', epoch)
    model.train()

    inputs = torch.rand(2, 3, 28, 28)
    labels = torch.FloatTensor([[1, 1], [0, 0]])

    inputs = inputs.to(device)
    labels = labels.to(device)
    outputs = model(inputs)
    loss = criterion(outputs.view(2, 2), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >= 2:
        model.apply(torch.quantization.disable_observer)
        pass

    if epoch >= 3:
        model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

    quant_model = deepcopy(model)
    quant_model = torch.quantization.convert(quant_model.eval().cpu(), inplace=False)

    with torch.no_grad():
        inp = torch.zeros([1, 3, 28, 28], device='cpu')
        model.eval().cpu()
        quant_model.eval().cpu()
        qout = quant_model.forward(inp)
        out = model.forward(inp)
        print(qout.view(2).tolist())
        print(out.view(2).tolist())

        model.to(device)

I printed the state_dicts of the non-converted and converted models for the case where the output from forwarding zeros diverged and for the case where it stayed the same:

diverged
# Forward zeros output
[0.28445518016815186, 0.3933817744255066] #not converted
[0.07719799876213074, 0.6743956208229065] #converted

<================ NON-CONVERTED MODEL ====================>
OrderedDict([('backbone0.0.activation_post_process.scale', tensor([0.0077])), ('backbone0.0.activation_post_process.zero_point', tensor([0])), ('backbone0.0.activation_post_process.activation_post_process.min_val', tensor(0.0006)), ('backbone0.0.activation_post_process.activation_post_process.max_val', tensor(0.9801)), ('backbone0.1.weight', tensor([[[[-4.4932]],

         [[ 3.9661]],

         [[-4.1858]]]])), ('backbone0.1.activation_post_process.scale', tensor([0.0039])), ('backbone0.1.activation_post_process.zero_point', tensor([127])), ('backbone0.1.activation_post_process.activation_post_process.min_val', tensor(-0.4932)), ('backbone0.1.activation_post_process.activation_post_process.max_val', tensor(0.0016)), ('backbone0.1.weight_fake_quant.scale', tensor([0.0028])), ('backbone0.1.weight_fake_quant.zero_point', tensor([0])), ('backbone0.1.weight_fake_quant.activation_post_process.min_vals', tensor([-0.3583])), ('backbone0.1.weight_fake_quant.activation_post_process.max_vals', tensor([0.0458])), ('backbone0.2.weight', tensor([4.2045])), ('backbone0.2.bias', tensor([-0.3767])), ('backbone0.2.running_mean', tensor([-0.1198])), ('backbone0.2.running_var', tensor([0.3677])), ('backbone0.2.num_batches_tracked', tensor(10)), ('backbone0.2.activation_post_process.scale', tensor([0.0341])), ('backbone0.2.activation_post_process.zero_point', tensor([59])), ('backbone0.2.activation_post_process.activation_post_process.min_val', tensor(-2.0061)), ('backbone0.2.activation_post_process.activation_post_process.max_val', tensor(2.3278)), ('backbone0.3.activation_post_process.scale', tensor([0.0184])), ('backbone0.3.activation_post_process.zero_point', tensor([0])), ('backbone0.3.activation_post_process.activation_post_process.min_val', tensor(0.0400)), ('backbone0.3.activation_post_process.activation_post_process.max_val', tensor(2.3427)), ('backbone1.0.weight', tensor([[[[-3.8434, -4.1207,  4.9514],
          [ 5.1071, -4.7582, -3.8411],
          [-3.6406, -4.1518, -1.5902]]],


        [[[-4.2797, -4.4012,  1.1095],
          [-5.2368,  5.8240,  0.5995],
          [-3.5678, -3.8644,  1.4833]]]])), ('backbone1.1.weight', tensor([-2.8446,  2.8394])), ('backbone1.1.bias', tensor([ 0.2494, -0.2510])), ('backbone1.1.running_mean', tensor([-12.0386,  -4.0614])), ('backbone1.1.running_var', tensor([157.3076, 138.7101])), ('backbone1.1.num_batches_tracked', tensor(10))])

<============== QUANT CONVERTED MODEL ===============>
OrderedDict([('backbone0.0.scale', tensor([0.0077])), ('backbone0.0.zero_point', tensor([0])), ('backbone0.1.weight', tensor([[[[-0.3597]],

         [[ 0.3569]],

         [[-0.3597]]]], size=(1, 3, 1, 1), dtype=torch.qint8,
       quantization_scheme=torch.per_channel_affine,
       scale=tensor([0.0028], dtype=torch.float64), zero_point=tensor([0]),
       axis=0)), ('backbone0.1.scale', tensor(0.0039)), ('backbone0.1.zero_point', tensor(127)), ('backbone0.1.bias', None), ('backbone0.2.weight', tensor([1.])), ('backbone0.2.bias', tensor([0.])), ('backbone0.2.running_mean', tensor([0.])), ('backbone0.2.running_var', tensor([1.])), ('backbone0.2.num_batches_tracked', tensor(0)), ('backbone1.0.weight', tensor([[[[-3.8434, -4.1207,  4.9514],
          [ 5.1071, -4.7582, -3.8411],
          [-3.6406, -4.1518, -1.5902]]],


        [[[-4.2797, -4.4012,  1.1095],
          [-5.2368,  5.8240,  0.5995],
          [-3.5678, -3.8644,  1.4833]]]])), ('backbone1.1.weight', tensor([-2.8446,  2.8394])), ('backbone1.1.bias', tensor([ 0.2494, -0.2510])), ('backbone1.1.running_mean', tensor([-12.0386,  -4.0614])), ('backbone1.1.running_var', tensor([157.3076, 138.7101])), ('backbone1.1.num_batches_tracked', tensor(10))])
still the same
# Forward zeros output
[0.4619605243206024, 0.3693790137767792] #not converted
[0.4619605243206024, 0.3693790137767792] #converted

<================ NON-CONVERTED MODEL ====================>
OrderedDict([('backbone0.0.activation_post_process.scale', tensor([0.0077])), ('backbone0.0.activation_post_process.zero_point', tensor([0])), ('backbone0.0.activation_post_process.activation_post_process.min_val', tensor(8.7249e-05)), ('backbone0.0.activation_post_process.activation_post_process.max_val', tensor(0.9802)), ('backbone0.1.weight', tensor([[[[4.5793]],

         [[3.9695]],

         [[4.1871]]]])), ('backbone0.1.activation_post_process.scale', tensor([0.0046])), ('backbone0.1.activation_post_process.zero_point', tensor([42])), ('backbone0.1.activation_post_process.activation_post_process.min_val', tensor(-0.1939)), ('backbone0.1.activation_post_process.activation_post_process.max_val', tensor(0.3904)), ('backbone0.1.weight_fake_quant.scale', tensor([0.0035])), ('backbone0.1.weight_fake_quant.zero_point', tensor([0])), ('backbone0.1.weight_fake_quant.activation_post_process.min_vals', tensor([-0.1654])), ('backbone0.1.weight_fake_quant.activation_post_process.max_vals', tensor([0.4444])), ('backbone0.2.weight', tensor([3.7733])), ('backbone0.2.bias', tensor([-3.2874])), ('backbone0.2.running_mean', tensor([0.4043])), ('backbone0.2.running_var', tensor([0.3758])), ('backbone0.2.num_batches_tracked', tensor(10)), ('backbone0.2.activation_post_process.scale', tensor([0.0358])), ('backbone0.2.activation_post_process.zero_point', tensor([65])), ('backbone0.2.activation_post_process.activation_post_process.min_val', tensor(-2.3070)), ('backbone0.2.activation_post_process.activation_post_process.max_val', tensor(2.2333)), ('backbone0.3.activation_post_process.scale', tensor([0.0179])), ('backbone0.3.activation_post_process.zero_point', tensor([0])), ('backbone0.3.activation_post_process.activation_post_process.min_val', tensor(0.)), ('backbone0.3.activation_post_process.activation_post_process.max_val', tensor(2.2680)), ('backbone1.0.weight', tensor([[[[ 4.3533, -3.5816,  5.0651],
          [-4.1010, -3.6161, -4.3417],
          [ 4.0427, -4.3517,  4.1981]]],


        [[[ 2.5564, -3.3695,  4.0380],
          [-3.9976, -7.2543, -4.0428],
          [ 3.7343, -3.5447,  2.7283]]]])), ('backbone1.1.weight', tensor([-0.5695, -2.9817])), ('backbone1.1.bias', tensor([-0.1784, -0.1795])), ('backbone1.1.running_mean', tensor([ 0.1879, -0.5091])), ('backbone1.1.running_var', tensor([16.9781, 18.2449])), ('backbone1.1.num_batches_tracked', tensor(10))])

<============== QUANT CONVERTED MODEL ===============>
OrderedDict([('backbone0.0.scale', tensor([0.0077])), ('backbone0.0.zero_point', tensor([0])), ('backbone0.1.weight', tensor([[[[0.4426]],

         [[0.4426]],

         [[0.4426]]]], size=(1, 3, 1, 1), dtype=torch.qint8,
       quantization_scheme=torch.per_channel_affine,
       scale=tensor([0.0035], dtype=torch.float64), zero_point=tensor([0]),
       axis=0)), ('backbone0.1.scale', tensor(0.0046)), ('backbone0.1.zero_point', tensor(42)), ('backbone0.1.bias', None), ('backbone0.2.weight', tensor([1.])), ('backbone0.2.bias', tensor([0.])), ('backbone0.2.running_mean', tensor([0.])), ('backbone0.2.running_var', tensor([1.])), ('backbone0.2.num_batches_tracked', tensor(0)), ('backbone1.0.weight', tensor([[[[ 4.3533, -3.5816,  5.0651],
          [-4.1010, -3.6161, -4.3417],
          [ 4.0427, -4.3517,  4.1981]]],


        [[[ 2.5564, -3.3695,  4.0380],
          [-3.9976, -7.2543, -4.0428],
          [ 3.7343, -3.5447,  2.7283]]]])), ('backbone1.1.weight', tensor([-0.5695, -2.9817])), ('backbone1.1.bias', tensor([-0.1784, -0.1795])), ('backbone1.1.running_mean', tensor([ 0.1879, -0.5091])), ('backbone1.1.running_var', tensor([16.9781, 18.2449])), ('backbone1.1.num_batches_tracked', tensor(10))])

Can you paste the code for printing as well?

Hello,

Here’s the current code, where I turn on QAT only near the end of training; it still has the same issue:

import torch
from torch import nn, optim
from torch.quantization import QuantStub, DeQuantStub
from copy import deepcopy

print(torch.__version__)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.backbone0 = nn.Sequential(
            QuantStub(),
            nn.Conv2d(3, 1, 1, bias=False),
            nn.BatchNorm2d(1),
            nn.ReLU(),
            DeQuantStub(),
        )
        self.backbone1 = nn.Sequential(
            nn.Conv2d(1, 2, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(2),
            nn.MaxPool2d(14),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.backbone0(x)
        x = self.backbone1(x)
        return x

model = Model()

# torch.quantization.fuse_modules(model, [['1', '2', '3'], ['4', '5']], inplace=True)

# model.backbone0.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
# torch.quantization.prepare_qat(model.backbone0, inplace=True)

optimizer = optim.Adam(model.parameters(), lr=1)
model.to(device)

criterion = nn.BCELoss()
for epoch in range(1000):
    # print('EPOCH', epoch)
    model.train()

    inputs = torch.rand(2, 3, 28, 28)
    labels = torch.FloatTensor([[1, 1], [0, 0]])

    inputs = inputs.to(device)
    labels = labels.to(device)
    outputs = model(inputs)
    loss = criterion(outputs.view(2, 2), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if epoch == 945: # turn on qat
        model.to('cpu')
        model.backbone0.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        torch.quantization.prepare_qat(model.backbone0, inplace=True)
        model.to(device)

    if epoch >= 950:
        model.apply(torch.quantization.disable_observer)
        pass

    if epoch >= 950:
        model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

    if epoch == 999:
        # print('MODEL', model)
        quant_model = deepcopy(model)
        quant_model = torch.quantization.convert(quant_model.eval().cpu(), inplace=False)

        with torch.no_grad():
            inp = torch.zeros([1, 3, 28, 28], device='cpu')
            model.eval().cpu()
            quant_model.eval().cpu()

            qout = quant_model.forward(inp)
            out = model.forward(inp)

            print(f"<============== EPOCH {epoch} ===============>")
            print(out.view(2).tolist(), "#not converted")
            print(qout.view(2).tolist(), "#quant converted")
            print(f"<============== NOT CONVERTED MODEL ===============>")
            print(model.state_dict())
            print(f"<============== QUANT CONVERTED MODEL ===============>")
            print(quant_model.state_dict())

            model.to(device)

It looks like in the case where they diverged, the state_dict still matches? Did you enable fake quantization when you compared the results?
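For reference, this is what I mean by enabling fake quantization before the comparison (a minimal sketch; it assumes the eager-mode helper torch.quantization.enable_fake_quant and the model/quant_model names from your snippet):

# make sure the FakeQuantize modules are active on the QAT model
# before comparing it against the converted model
model.apply(torch.quantization.enable_fake_quant)
model.eval().cpu()
quant_model.eval().cpu()
with torch.no_grad():
    inp = torch.zeros(1, 3, 28, 28)
    print(model(inp).view(2).tolist())        # QAT model with fake quant enabled
    print(quant_model(inp).view(2).tolist())  # converted quantized model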

Yes, it was enabled, although I’m wondering whether enabling it only on a sequential container layer in the module, as I did, is incorrect?
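One way to confirm this (a minimal sketch; it assumes the FakeQuantize modules inserted by prepare_qat expose a fake_quant_enabled flag, as in the eager-mode API):

for name, module in model.named_modules():
    if hasattr(module, 'fake_quant_enabled'):
        # only modules under backbone0 should show up here
        print(name, module.fake_quant_enabled)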

Updating from torch 1.5.1 to the torch nightly fixed the issue! (I did not try 1.6.)
