PT2E quantization doesn't reduce the model size

I’m trying to quantize a simple model using PT2E quantization workflow and reproduced the code from the Quick Start Guide. However, the sizes of the original floating point model and the quantized model are the same. What am I missing here?

import os
import torch
from torch.export import export
from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, convert_pt2e
import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
    get_default_x86_inductor_quantization_config,
)

class M(torch.nn.Module):
    """Toy two-layer linear model used to demonstrate PT2E quantization."""

    def __init__(self):
        super().__init__()
        # Two stacked 1024x1024 linear layers (~8 MiB of fp32 weights total).
        self.linear1 = torch.nn.Linear(1024, 1024)
        self.linear2 = torch.nn.Linear(1024, 1024)

    def forward(self, x):
        # Chain the layers directly; there is no activation in between.
        return self.linear2(self.linear1(x))


# Build the fp32 baseline model in eval mode.
float_model = M().eval()

# One calibration batch: 128 samples of dimension 1024.
example_inputs = (torch.randn(128, 1024),)

# Capture the model graph; PT2E quantization operates on the exported GraphModule.
exported_model = export(float_model, example_inputs).module()

# Configure the x86 Inductor quantizer with its default (global) config.
quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())

# Insert observers, run one batch to calibrate them, then convert the graph.
prepared_model = prepare_pt2e(exported_model, quantizer)
prepared_model(*example_inputs)
quantized_model = convert_pt2e(prepared_model)

# Compare on-disk state_dict sizes of the float vs. "quantized" model.
# NOTE(review): after convert_pt2e the graph uses a q/dq representation, so
# the state_dict can still hold fp32 weights — which is why both files come
# out the same size here (the symptom this question is about).
torch.save(float_model.state_dict(), 'float_model.pt')
torch.save(quantized_model.state_dict(), 'quant_model.pt')
float_model_size_mb = os.path.getsize('float_model.pt') / 1024 / 1024
quant_model_size_mb = os.path.getsize('quant_model.pt') / 1024 / 1024

print(f'Float model size: {float_model_size_mb:.2f} MiB')
print(f'Quantized model size: {quant_model_size_mb:.2f} MiB')

The output on Google Colab is

Float model size: 8.01 MiB
Quantized model size: 8.01 MiB

I made some progress here and wanted to document it. It appears that the model should be exported again after the quantization, which the Quick Start Guide fails to mention. So this code is working on my laptop (PyTorch Version: 2.7.1+cu126, Torchao Version: 0.14.1):

import os
import torch
from torch.export import export
from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, convert_pt2e
import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import ( 
    X86InductorQuantizer,
    get_default_x86_inductor_quantization_config,
)

class M(torch.nn.Module):
    """Toy two-layer linear model used to demonstrate PT2E quantization.

    NOTE(review): the posted snippet had ``_init_`` / ``super()._init_()``
    instead of ``__init__`` and had ``forward`` de-indented out of the class
    body — transcription artifacts (the author states this version runs).
    Restored here so the class is actually constructible and callable.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(1024, 1024)
        self.linear2 = torch.nn.Linear(1024, 1024)

    def forward(self, x):
        # Two stacked linear layers, no nonlinearity in between.
        x = self.linear1(x)
        out = self.linear2(x)

        return out

# Build the fp32 baseline model in eval mode.
float_model = M().eval()

# One calibration batch: 128 samples of dimension 1024.
example_inputs = (torch.randn(128, 1024),)

# Capture the model graph; PT2E quantization operates on the exported GraphModule.
exported_model = export(float_model, example_inputs).module()

# Configure the x86 Inductor quantizer with its default (global) config.
quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())

# Insert observers, calibrate with one batch, convert to the q/dq graph.
prepared_model = prepare_pt2e(exported_model, quantizer)
prepared_model(*example_inputs)
converted_model = convert_pt2e(prepared_model)

# This step was missing: re-exporting after conversion is what makes the
# saved state_dict shrink (the converted model alone still carries fp32
# weights, as the size printout below shows).
quantized_model = torch.export.export(converted_model, example_inputs).module()

# NOTE(review): the posted snippet wrapped every string literal below in
# curly "smart quotes" (a copy/paste artifact), which is a SyntaxError in
# Python — replaced with plain ASCII quotes so the code parses.
torch.save(float_model.state_dict(), '/tmp/float_model.pt')
torch.save(converted_model.state_dict(), '/tmp/conv_model.pt')
torch.save(quantized_model.state_dict(), '/tmp/quant_model.pt')

float_model_size_mb = os.path.getsize('/tmp/float_model.pt') / 1024 / 1024
conv_model_size_mb = os.path.getsize('/tmp/conv_model.pt') / 1024 / 1024
quant_model_size_mb = os.path.getsize('/tmp/quant_model.pt') / 1024 / 1024

print(f'Float model size: {float_model_size_mb:.2f} MiB')
print(f'Converted model size: {conv_model_size_mb:.2f} MiB')
print(f'Quantized model size: {quant_model_size_mb:.2f} MiB')

The output is:

Float model size: 8.01 MiB
Converted model size: 10.04 MiB
Quantized model size: 2.03 MiB

Curiously, it still doesn’t work in Google Colab (PyTorch Version: 2.8.0+cu126, Torchao Version: 0.14.1), returning 8.01 MiB for all models.

1 Like

The model after PT2E uses a q/dq representation (see PyTorch 2 Export Post Training Quantization — torchao main documentation), but we do fold quantize ops by default: ao/torchao/quantization/pt2e/quantize_pt2e.py at bcd5dbc0cedcd2a330bb1b55bdbfb8625273cf23 · pytorch/ao · GitHub

so it could be some regression on the export side