Changing Qconfig to set datatype to int8

I want to do QAT using torch.fx. My code is here:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.models import resnet18
from torch.ao.quantization import (
  get_default_qconfig_mapping,
  get_default_qat_qconfig_mapping,
  QConfigMapping,
)
from torch.ao.quantization.qconfig_mapping import _get_symmetric_qnnpack_qat_qconfig_mapping
import torch.ao.quantization.quantize_fx as quantize_fx

from torch.ao.quantization import QConfig, default_observer, default_per_channel_weight_observer
from torch.ao.quantization.observer import (
    MinMaxObserver,
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
    _PartialWrapper,
    default_fixed_qparams_range_0to1_observer,
    default_fixed_qparams_range_neg1to1_observer,
    default_weight_observer,
    default_placeholder_observer,
        )
from torch.ao.quantization.fake_quantize import (
            FusedMovingAvgObsFakeQuantize,
            default_weight_fake_quant,
            FixedQParamsFakeQuantize,
        )
import onnx
import copy
from torch.ao.quantization.backend_config import (
            BackendConfig,
            BackendPatternConfig,
            DTypeConfig,
            ObservationType,
            get_qnnpack_backend_config,
        )
from torch.ao.quantization.qconfig import (
            default_reuse_input_qconfig,
            default_per_channel_symmetric_qnnpack_qat_qconfig,
            QConfigAny
        )

from typing import Any, Callable, Dict, Tuple, Union, List
from tqdm import tqdm


#build model, using ResNet18 on CIFAR10 Dataset
class CIFAR10ResNet(nn.Module):
    """ResNet-18 backbone adapted to 32x32 CIFAR-10 images.

    The ImageNet-pretrained network is reused, but its first convolution is
    replaced by a newly created 3x3, stride-1 convolution and its classifier
    by a newly created linear layer with ``num_classes`` outputs.

    Args:
        num_classes: Number of output features of the final linear layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # `pretrained=True` is deprecated in torchvision (>= 0.13); the
        # `weights` argument below selects the same ImageNet checkpoint.
        resnet = resnet18(weights="IMAGENET1K_V1")

        # Swap the stem convolution for a 3x3 / stride-1 one. This layer is
        # newly created, so it does not carry pretrained weights.
        resnet.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)

        # Replace the classifier head so it emits `num_classes` values.
        resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)
        self.resnet = resnet

    def forward(self, x):
        """Run ``x`` through the wrapped ResNet and return its output."""
        return self.resnet(x)


#build dataset
# Per-channel statistics handed to Normalize; with mean = std = 0.5 the
# normalisation computes (x - 0.5) / 0.5 for every channel.
_CHANNEL_MEAN = (0.5, 0.5, 0.5)
_CHANNEL_STD = (0.5, 0.5, 0.5)

# Preprocessing pipeline applied to every training image.
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)

# CIFAR-10 training split, downloaded into ./data when it is not there yet.
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 64 samples.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
)


model = CIFAR10ResNet()


#build example input
# Take one real training batch as the example input, falling back to a random
# tensor if the loader yields nothing. The batch is deliberately left on the
# CPU: an unconditional `.cuda()` call would crash on machines without a GPU,
# and the model itself has not been moved to a GPU at this point.
dummy_input, _ = next(iter(train_loader), (torch.randn(1, 3, 32, 32), None))



# Work on a copy so the original float model stays untouched.
model_to_quantize = copy.deepcopy(model)


#get backend, it is just the default backend config, I did not change anything
backend_config_s = get_qnnpack_backend_config()



#get qconfig_mapping, I want to use _get_symmetric_qnnpack_qat_qconfig_mapping because I do not want quint8, I need int8
qconfig_mapping = _get_symmetric_qnnpack_qat_qconfig_mapping()



#model_prepared
# `example_inputs` is documented as a tuple of positional arguments for the
# model's forward(), so the single tensor is wrapped in a 1-tuple.
model_prepared = quantize_fx.prepare_qat_fx(
    model_to_quantize,
    qconfig_mapping,
    (dummy_input,),
    backend_config=backend_config_s,
)


#training
# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_prepared = model_prepared.to(device)
model_prepared.train()

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model_prepared.parameters(), lr=0.01, momentum=0.9)

epochs = 2
for epoch in range(epochs):
    for batch_images, batch_targets in train_loader:
        # Move the batch to the device the model lives on.
        batch_images, batch_targets = batch_images.to(device), batch_targets.to(device)

        optimizer.zero_grad()
        loss = criterion(model_prepared(batch_images), batch_targets)
        loss.backward()
        optimizer.step()
    # Reports the loss of the last batch of the epoch only.
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')


#convert to onnx
# Conversion and export are both carried out on the CPU.
device = torch.device( "cpu")
dummy_input = dummy_input.to(device)
model_prepared = model_prepared.to(device)

# Turn the prepared (fake-quantized) model into its converted form.
model_quantized = quantize_fx.convert_fx(model_prepared)
model_quantized.eval()

export_options = dict(
    verbose=False,
    input_names=['input'],
    output_names=['output'],
)
torch.onnx.export(model_quantized, dummy_input, "quantized_model.onnx", **export_options)

I first exported an ONNX model by following the official demo, and it seemed to work fine. However, due to ECU restrictions I am limited to TensorRT 8.6.1, which does not support the cast-to-uint8 operation. Therefore, I had to either remove the Cast layer or change it to a cast to int8. My first attempt was to change it to a cast to int8. To do this, I changed the qconfig_mapping to _get_symmetric_qnnpack_qat_qconfig_mapping(), because the activation quantization data type in the default_symmetric_qnnpack_qconfig used by this function is int8, which allowed me to change the cast to int8. I also found that if I do not obtain the qconfig_mapping via qconfig_mapping = get_default_qat_qconfig_mapping("qnnpack"), then I need to pass the backend config to the prepare function. My backend config has not been customized at all; it is just the one returned by the function in the source code. After training, I obtained the ONNX model; some typical structures from it are shown below.



I feel there are issues with this model.
First, the DQ layer seems to have external Identity inputs, which I understand to be the zero_point and scale, whereas in the unmodified ONNX graph these values are folded into the DQ layer itself, like this (unmodified ONNX graph):

Secondly, some of the residual branches still use uint8 instead of int8 (in the red box above), so their Cast layers still cast to uint8, while most have been changed to cast to int8 (in the green box above). I don’t understand why this is. Also, the arrangement of the Q and DQ nodes after the Add operator is not the same everywhere (compare the two structures in the two yellow boxes above).
I hope someone can help me solve my issues!!! :pray: :pray: :pray: :pray: :pray: