Huge accuracy drop from QAT model after convert_pt2e

I’m trying to quantize a CNN with QAT, and after running `convert_pt2e` the F1 score drops from around 90% down to 70%. Here’s my code; I’m not sure if I’m doing something wrong or how to debug this.

import torch
from torch.export import export_for_training
from torch.ao.quantization import move_exported_model_to_eval, quantize_pt2e

model_for_export = get_model(
    model_name=model_name,
    n_classes=n_classes,
    input_size=input_size,
    **model_kwargs,
)

quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
    pt2e_quantizer.get_symmetric_quantization_config()
)
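For context, my mental model of what the symmetric config does to each tensor is roughly the following. This is a hand-rolled sketch of symmetric per-tensor int8 fake-quantization, not the actual observer implementation:

```python
import torch

def fake_quant_symmetric(x: torch.Tensor) -> torch.Tensor:
    """Rough sketch of symmetric per-tensor int8 fake-quantization."""
    scale = x.abs().max() / 127          # symmetric: zero_point is 0
    q = torch.clamp(torch.round(x / scale), -127, 127)
    return q * scale                     # dequantize back to float

torch.manual_seed(0)
w = torch.randn(64, 64)
err = (w - fake_quant_symmetric(w)).abs().max()
scale = w.abs().max() / 127
print(bool(err <= scale / 2 + 1e-6))  # rounding error stays within half a quant step
```

So I'd expect convert to introduce only errors on the order of half a quantization step per tensor, not a 20-point F1 drop.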

model_for_export.eval()
sample_inputs = (torch.randn(1, input_size, input_size, 3),)
pt2e_export = export_for_training(model_for_export, sample_inputs).module()
pt2e_export = quantize_pt2e.prepare_qat_pt2e(pt2e_export, quantizer)

if qat_best_path.exists():
    print("Loading best QAT weights before conversion...")
    pt2e_export.load_state_dict(torch.load(qat_best_path, weights_only=True))

pt2e_export = pt2e_export.to("cuda")
pt2e_export = move_exported_model_to_eval(pt2e_export)

At this point the pt2e_export model has an F1 score of around 90%, so all looks good. Next I run:

test_model = quantize_pt2e.convert_pt2e(pt2e_export, fold_quantize=False)
test_model = move_exported_model_to_eval(test_model)

This test_model now has an F1 score of around 70%.

Running some forward passes, here’s what I see:

pt2e_export(torch.zeros(1, 128, 128, 3).cuda())
tensor([[-13.6758, -14.1442, -11.1460, -22.0603, -16.6179,  -7.9119]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
test_model(torch.zeros(1, 128, 128, 3).cuda())
tensor([[-12.3561, -14.5184, -10.1938, -19.1520, -15.7540, -10.1938]],
       device='cuda:0')
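To quantify the divergence, I computed the SQNR between the two outputs (a quick sketch; `sqnr_db` here is my own helper, not a library function):

```python
import torch

def sqnr_db(ref: torch.Tensor, test: torch.Tensor) -> float:
    """Signal-to-quantization-noise ratio in dB; higher means closer agreement."""
    noise = ref - test
    return (20 * torch.log10(ref.norm() / noise.norm())).item()

# Logits from the two forward passes above.
ref = torch.tensor([-13.6758, -14.1442, -11.1460, -22.0603, -16.6179, -7.9119])
test = torch.tensor([-12.3561, -14.5184, -10.1938, -19.1520, -15.7540, -10.1938])
print(f"{sqnr_db(ref, test):.1f} dB")  # roughly 19 dB
```

That seems low to me for a model that was supposedly trained with fake-quant in the loop.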

Even with an all-zeros input the outputs are pretty far apart (the last logit goes from -7.91 to -10.19, and two entries of the converted model’s output are identical at -10.1938, which looks like a quantization artifact). For context, the model is MobileNetV3. I’m not sure what’s going on here at all. Any pointers on how to debug this would be appreciated.

Unrelated, but this guide no longer works: PyTorch 2 Export Quantization-Aware Training (QAT) — torchao 0.13 documentation.