Significant Slowdown in Inference Speed with Quantized Model in PyTorch 2.1 pt2e

I’ve recently encountered an issue with PyTorch 2.1 where inference with a quantized model is significantly slower than with its FP32 counterpart (running on CUDA): the quantized model’s inference is over 10 times slower.

Here’s the code snippet that reproduces this behavior:

from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
import torch
import torchvision
import time

def loop(model, inputs):
    # warmup
    with torch.no_grad():
        for i in range(20):
            model(inputs)

    # timed runs
    res = []
    with torch.no_grad():
        for i in range(20):
            s = time.time()
            model(inputs)
            res.append(time.time() - s)
    return res

if __name__ == '__main__':
    inputs = torch.randn(32, 3, 224, 224).cuda()
    model = torchvision.models.vgg16().eval().cuda()
    res_fp32 = loop(model, inputs)
    print(res_fp32)

    model = torch._export.capture_pre_autograd_graph(model, [inputs])
    quantizer = XNNPACKQuantizer().set_global(
        get_symmetric_quantization_config(is_per_channel=True)
    )
    model = prepare_pt2e(model, quantizer)
    with torch.no_grad():
        model(inputs)
    model = convert_pt2e(model)

    res_quant = loop(model, inputs)
    print(res_quant)

The result on my local machine:

[..., 0.055181026458740234, 0.055129289627075195]
[..., 1.3574013710021973, 1.354731559753418]
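(One caveat on these numbers: CUDA kernels launch asynchronously, so timing with time.time() alone is approximate. A synchronized variant of the timed section, just as a sketch, would look like this:)

# sketch: synchronized timing variant of the loop above, for more reliable per-iteration numbers
res = []
with torch.no_grad():
    for i in range(20):
        torch.cuda.synchronize()
        s = time.time()
        model(inputs)
        torch.cuda.synchronize()
        res.append(time.time() - s)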

Here are some profiler results (screenshots attached for the fp32 model and the quantized model).

Performance on Colab is similar. Link to Colab

I’m seeking insights into any potential solutions or optimizations that I might be missing.

My goal is to evaluate the performance of the quantized model after PTQ.

Any advice or guidance would be greatly appreciated. Thank you.

The bottleneck is per-channel quantization. For details, see here.
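If you want to check whether that is what’s hitting you, one quick experiment (just a sketch against the script above) is to ask the quantizer for per-tensor weight quantization instead of per-channel:

from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

# sketch: same as the original script, but per-tensor instead of per-channel weights
quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_per_channel=False)
)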

So to summarize Jerry’s comment in the issue: what you have is kind of a fake-quantized model that would then be lowered to the actual backend in question.

see (prototype) PyTorch 2 Export Post Training Quantization — PyTorch Tutorials 2.2.0+cu121 documentation

The model produced at this point is not the final model that runs on the device, it is a reference quantized model that captures the intended quantized computation from the user, expressed as ATen operators and some additional quantize/dequantize operators, to get a model that runs on real devices, we’ll need to lower the model. For example, for the models that run on edge devices, we can lower with delegation and ExecuTorch runtime operators.
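You can see this directly in the script above by printing the converted model; just a sketch, but something like:

# sketch: after convert_pt2e, `model` is a torch.fx GraphModule whose generated code
# shows explicit quantize/dequantize ops around plain aten convolutions/linears
print(model.code)
model.graph.print_tabular()  # per-node view of the graph (needs the tabulate package)

Those quantize/dequantize pairs are what make it a “reference” model rather than something with fused quantized kernels.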

Further, my understanding is that we (only kind of) have kernels for quantized ops on CUDA at the moment, so it doesn’t look like this method will work.

There is another option, as described in the recent blog posts (1, 2, 3), where you can apply quantization using the APIs in the torchao repository. HOWEVER, it currently only works for linear ops, not convolutions, so I doubt you’ll see a ton of improvement for VGG, which is mainly convolutional. It works great for transformer models though.

import torch
from torchao.quantization import quant_api

# some user model and example input
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')

# convert linear modules to quantized linear modules
quant_api.change_linear_weights_to_int8_dqtensors(model)

# compile the model to improve performance
model_c = torch.compile(model, mode='max-autotune')
model_c(input)
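If you want numbers, the loop() helper from the original snippet can be reused for a rough comparison against an unquantized bf16 copy of the model (just a sketch; don’t read too much into such a tiny linear layer, and note the warmup loop also absorbs the compilation time):

# sketch: rough timing comparison against an unquantized bf16 baseline,
# reusing the loop() helper from the original post
baseline = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
print(loop(baseline, input))
print(loop(model_c, input))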

Thank you so much for your informative reply!

I have some more questions:

  1. I’m still a little confused: does this lowering already work now? I didn’t find tutorials about it, for example how to lower to a CUDA device.
  2. I’m kinda curious about the relationship between torchao and pt2e quantization. Is torchao more likely to remain a standalone project?

  1. No, it doesn’t.

  2. They’re both made by our team (PyTorch AO). pt2e does a lot of things and needs graph information for some of its features; the torchao stuff does not. They may both be folded together at some point in the future, and both are under active development. torchao will continue to exist or be moved into core PyTorch in the future. For your use case, torchao is ready now; pt2e has had more development oriented towards CPU/edge and less focus on CUDA.


Gotcha! Thank you so much!!!