RuntimeError: Unsupported qscheme: per_channel_affine during fx qat

I use the default QAT qconfig for my model:
qconfig_mapping = get_default_qat_qconfig_mapping("x86")
After I prepare the model with this config
model.backbone = prepare_qat_fx(model.backbone, qconfig_mapping, example_inputs)
I see that per_channel_symmetric is going to be applied to the weights of my first layer.

>> dict(model.backbone.named_modules())['backbone.conv_stem'].qconfig.weight.p.keywords['qscheme']
>> torch.per_channel_symmetric

Then, after training for one epoch, I call convert_fx(model.backbone) and get the error above.
So there are actually two questions:

  1. I guess the weight qscheme was changed during training for some reason, but why?
  2. Why can't the qscheme (per_channel_affine) be applied, especially considering that it was (I guess) changed automatically?

It is still per_channel_symmetric. Quantized tensors just don't use that qscheme: we convert it to per_channel_affine (symmetric is simply affine with all zero points equal to 0) to simplify kernel support, so you don't have operations between per_channel_affine and per_channel_symmetric tensors, etc.
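To illustrate the point above, here is a minimal sketch. The weight tensor and its shape are made up for illustration and are not taken from the model in question. Even when the per-channel parameters are symmetric (all zero points are 0), the resulting quantized tensor reports per_channel_affine as its qscheme.

```python
import torch

# Hypothetical conv weight: 4 output channels, 3 input channels, 3x3 kernel.
weight = torch.randn(4, 3, 3, 3)

# Symmetric per-channel parameters: one scale per output channel, zero points all 0.
scales = weight.abs().amax(dim=(1, 2, 3)) / 127.0
zero_points = torch.zeros(4, dtype=torch.int64)

q_weight = torch.quantize_per_channel(
    weight, scales, zero_points, axis=0, dtype=torch.qint8
)

# The tensor is tagged per_channel_affine even though the parameters are symmetric.
print(q_weight.qscheme())
print(q_weight.q_per_channel_zero_points())
```

So seeing per_channel_affine on a converted weight does not by itself mean the observer's scheme changed during training.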

I see.
But what about the main point? Why do we get the error for a simple conv layer?

sorry, not exactly sure, could you give me a minimal repro?

sure, here it is

from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx
import timm
from tqdm import tqdm
import torch
from torch import nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = timm.create_model('mobilenetv3_large_100', pretrained=True, num_classes=3).to(device)

qconfig_mapping = get_default_qat_qconfig_mapping("x86")
example_inputs = (torch.rand(size=(1, 3, 128, 128)).to(device),)  # prepare_qat_fx documents a tuple of inputs

model = prepare_qat_fx(model, qconfig_mapping, example_inputs)

# train
def fake_train(model, n_samples=5, bs=64, device='cpu'):
    optimizer = torch.optim.AdamW(model.parameters(), lr=10e-4)
    loss = nn.CrossEntropyLoss()
    
    model.train()
    for _ in tqdm(range(n_samples)):
        x = torch.rand(size=(bs,3,128,128)).to(device)
        y = torch.empty(bs, dtype=torch.long).random_(3).to(device)

        optimizer.zero_grad()
        y_hat = model(x)
        output = loss(y_hat, y)
        output.backward()
        optimizer.step()

fake_train(model, device=device)

quantized_model = convert_fx(model)



!pip freeze | grep 'torch'


pytorch-lightning==2.0.3
pytorch-triton==2.1.0+6e4932cda8
torch==2.0.1+cu118
torchaudio==2.0.2+cu118
torchmetrics==0.11.4
torchvision==0.15.2+cu118

@jerryzh168, any insights?

sorry for the late reply, this looks pretty weird, I assume it comes from https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp but when I search for “Unsupported”, it looks like per_channel_affine should be captured by other if/else branches

OK I think it might come from here: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp#L37

does it work if you move the model to cpu?
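For reference, a minimal sketch of that suggestion. TinyNet is a hypothetical stand-in, not the timm model from the repro, and the input sizes are arbitrary. The idea is to train on whatever device is available, then take a copy of the prepared model, move it to CPU, and only then call convert_fx.

```python
import copy

import torch
from torch import nn
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx

device = "cuda" if torch.cuda.is_available() else "cpu"


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real model."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


model = TinyNet().to(device).train()
example_inputs = (torch.rand(1, 3, 32, 32).to(device),)
prepared = prepare_qat_fx(
    model, get_default_qat_qconfig_mapping("x86"), example_inputs
)

# One forward pass in train mode so the fake-quant observers record value ranges.
prepared(torch.rand(4, 3, 32, 32).to(device))

# Convert a CPU copy; the training copy stays on `device`.
cpu_copy = copy.deepcopy(prepared).to("cpu").eval()
quantized_model = convert_fx(cpu_copy)

out = quantized_model(torch.rand(1, 3, 32, 32))
print(out.shape)
```

Converting a deep copy keeps the original prepared model intact, so training can continue on the GPU afterwards.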

Yes, it does, thanks. Maybe you also have some ideas about a further problem.
I have a custom architecture (object detection), and passing the whole model to prepare_fx throws exceptions because some of its operations are not traceable. So I decided to quantize only the backbone (a classic timm feature extractor), and the only way I found to achieve this is:
model.backbone = prepare_fx(model.backbone)

After every training epoch I run validation on the converted model:

        quantized_model = copy.deepcopy(model).to('cpu')
        quantized_backbone = convert_fx(quantized_model.backbone)
        quantized_model.backbone = quantized_backbone

and the performance of the model seems to be very low.
So maybe you have some ideas on points such as:

  • maybe there is a preferred way of excluding submodules from quantization (I tried setting None as the qconfig for different modules, but that doesn't work because the whole module passed to prepare_fx gets traced anyway);
  • maybe I need to do something additional along with my direct backbone substitution.
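One option that may address the first bullet: prepare_fx and prepare_qat_fx accept a prepare_custom_config argument, and PrepareCustomConfig has set_non_traceable_module_names, which tells the tracer to treat the named submodules as opaque leaves instead of tracing into them. Below is a minimal sketch under assumptions. Detector and Head are hypothetical stand-ins for the real architecture, and whether this fits a real detection head is not confirmed anywhere in this thread.

```python
import torch
from torch import nn
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx


class Head(nn.Module):
    """Hypothetical head with data-dependent control flow (not FX-traceable)."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 3)

    def forward(self, feats):
        pooled = feats.mean(dim=(2, 3))
        if pooled.sum() > 0:  # a branch on tensor values breaks symbolic tracing
            pooled = pooled * 2
        return self.fc(pooled)


class Detector(nn.Module):
    """Hypothetical model: traceable backbone plus non-traceable head."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
        self.head = Head()

    def forward(self, x):
        return self.head(self.backbone(x))


# Quantize everything by default, but leave the head in float.
qconfig_mapping = get_default_qat_qconfig_mapping("x86")
qconfig_mapping.set_module_name("head", None)

# Ask the tracer to keep `head` as a single opaque call instead of tracing into it.
prepare_custom_config = PrepareCustomConfig().set_non_traceable_module_names(["head"])

model = Detector().train()
example_inputs = (torch.rand(1, 3, 32, 32),)
prepared = prepare_qat_fx(
    model,
    qconfig_mapping,
    example_inputs,
    prepare_custom_config=prepare_custom_config,
)

prepared(torch.rand(2, 3, 32, 32))  # stand-in for the training loop

quantized_model = convert_fx(prepared.eval())
out = quantized_model(torch.rand(2, 3, 32, 32))
print(out.shape)
```

With this layout the whole model is passed to prepare and convert in one call, so the backbone does not need to be swapped in and out by hand.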

And one more question, not directly related:
is int8 the only quantization precision available?
