QAT model is not performing as expected when compared to the original model

I’m trying to perform QAT on MobileNetV2 with the goal of converting the model to TFLite. However, after training the model, I run a benchmarking script to compare it against the original model and find that performance degrades significantly.

Here are the important code snippets:

import torch
from torchvision import models
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import XNNPACKQuantizer, get_symmetric_quantization_config


model = models.mobilenet_v2(weights='DEFAULT')

example_inputs = (next(iter(dataloader))[0].to(device),)
model = torch.export.export_for_training(model, example_inputs).module()
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_qat=True))

model = prepare_qat_pt2e(model, quantizer)

train_model(model)

I have only included what I thought was relevant, since I didn’t want to add confusion with all of my helper functions (rough sketches of the two loss helpers are included after the training loop).

def train_model(model):

    # Make sure the model is on the correct device before training
    model = model.to(device)

    for phase in ['train', 'val']:
        is_train = phase == 'train'

        if is_train:
            torch.ao.quantization.move_exported_model_to_train(model)
        else:
            # Switch to evaluation mode to perform inference
            torch.ao.quantization.move_exported_model_to_eval(model)

        data_loader = train_loader if is_train else val_loader

        running_loss = 0.0
        total_samples = 0
        predictions, ground_truths, probabilities = [], [], []

        with tqdm(total=len(data_loader), desc=f"{phase.capitalize()} Epoch {epoch + 1}/{epochs}") as pbar:
            for inputs, labels in data_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                # Zero gradients only during training
                if is_train:
                    optimizer.zero_grad()

                # Enable gradients only in training phase
                with torch.set_grad_enabled(is_train):
                    model_logits = model(inputs)

                    soft_loss = compute_distillation_loss(model_logits)

                    label_loss, probs, preds = compute_loss_and_predictions(model_logits, labels, criterion)

                    # Compute weighted combination of the distillation and cross entropy losses
                    loss = soft_target_loss_weight * soft_loss + ce_loss_weight * label_loss

                    # Backward pass and optimizer step in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Accumulate statistics for the running average loss
                running_loss += loss.item() * inputs.size(0)
                total_samples += inputs.size(0)

                # Update progress bar with average loss so far
                pbar.set_postfix(loss=f"{running_loss / total_samples:.4f}")
                pbar.update(1)

quantized_model = convert_pt2e(model, fold_quantize=False)
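
For reference, the two loss helpers used above look roughly like this (simplified sketches only; the real implementations also track metrics, and in this sketch the teacher logits and distillation temperature are assumed to be available from the enclosing scope):

import torch.nn.functional as F

def compute_distillation_loss(student_logits):
    # Soft-target (KL) distillation loss against cached teacher logits for the
    # current batch; `teacher_logits` and `temperature` are assumptions here,
    # taken from the enclosing scope.
    soft_targets = F.softmax(teacher_logits / temperature, dim=1)
    log_probs = F.log_softmax(student_logits / temperature, dim=1)
    return F.kl_div(log_probs, soft_targets, reduction="batchmean") * (temperature ** 2)

def compute_loss_and_predictions(logits, labels, criterion):
    # Hard-label loss plus per-sample probabilities and argmax predictions
    loss = criterion(logits, labels)
    probs = torch.softmax(logits, dim=1)
    preds = probs.argmax(dim=1)
    return loss, probs, preds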

Actual vs expected behavior:

I would expect the quantized model to perform better than the original model (smaller and faster), but it does not.

                                      Original        QAT
Model Size (MB)                         9.1899    11.1504
Inference Time (sec/sample)           0.002896    0.011141
Throughput (samples/sec)                345.29       89.76
Energy per Sample (Joules)              0.3436    1.350853
Throughput per Watt (samples/sec/W)       2.91        0.74
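
The latency and throughput numbers come from a loop roughly along these lines (a simplified sketch that reuses val_loader and device from above; the energy and power measurements are taken separately and are not shown here):

import time
import torch

def benchmark(model, data_loader, warmup_batches=10):
    # Average per-sample latency and throughput over the loader (CPU timing sketch;
    # the model is assumed to already be in eval/inference mode).
    elapsed, n_samples = 0.0, 0
    with torch.no_grad():
        for i, (inputs, _) in enumerate(data_loader):
            inputs = inputs.to(device)
            start = time.perf_counter()
            model(inputs)
            if i >= warmup_batches:  # ignore warm-up iterations
                elapsed += time.perf_counter() - start
                n_samples += inputs.size(0)
    return elapsed / n_samples, n_samples / elapsed  # sec/sample, samples/sec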

What makes this even stranger is that if I switch to FX Graph Mode QAT, I get the expected behavior. However, I need to use the PT2E export-based quantization flow, since I want to use the ai-edge-torch API to convert my model to TFLite.

                                      Original        QAT
Model Size (MB)                         9.1899      2.3465
Inference Time (sec/sample)           0.002896    0.000250
Throughput (samples/sec)                345.29     4003.28
Energy per Sample (Joules)              0.3436      0.0271
Throughput per Watt (samples/sec/W)       2.91       36.85
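
For reference, the FX Graph Mode flow I compared against is roughly the following (a sketch; it reuses the same training loop, except the mode switching is done with the usual train()/eval() calls instead of the move_exported_model_* helpers):

import copy
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx

# Prepare a fresh copy of the float model for FX Graph Mode QAT
float_model = models.mobilenet_v2(weights='DEFAULT').train()
qconfig_mapping = get_default_qat_qconfig_mapping("qnnpack")
fx_model = prepare_qat_fx(copy.deepcopy(float_model), qconfig_mapping, example_inputs)

train_model(fx_model)  # same loop, but switching modes via fx_model.train()/fx_model.eval()

fx_model.eval()
quantized_fx_model = convert_fx(fx_model)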

Additionally, when I print the resulting PT2E QAT model I get the following:

GraphModule(
  (features): Module(
    (0): Module(
      (1): Module()
    )
    (1): Module(
      (conv): Module(
        (0): Module(
          (1): Module()
        )
        (2): Module()
      )
    )
    (2): Module(
      (conv): Module(
        (0): Module(
          (1): Module()
        )
        (1): Module(
          (1): Module()
        )
        (3): Module()
      )
    )
    (3): Module(
...

I would have expected it to look more similar to the QAT model produced by FX Graph Mode quantization, which leads me to believe it is not training correctly. The FX Graph Mode model is shown below:

GraphModule(
  (features): Module(
    (0): Module(
      (0): QuantizedConv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), scale=0.22475136816501617, zero_point=113, padding=(1, 1))
      (2): ReLU6(inplace=True)
    )
    (1): Module(
      (conv): Module(
        (0): Module(
          (0): QuantizedConv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), scale=0.36381739377975464, zero_point=112, padding=(1, 1), groups=32)
          (2): ReLU6(inplace=True)
        )
        (1): QuantizedConv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), scale=0.5194709300994873, zero_point=139)
      )
    )
...
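
For completeness, the TFLite conversion step I am targeting afterwards looks roughly like this (a sketch following the ai-edge-torch examples; the exact arguments may differ, and a quant_config describing the PT2E quantizer may also be required per the ai-edge-torch docs):

import ai_edge_torch

# Convert the PT2E-quantized model to TFLite (sketch; argument details are assumptions)
sample_inputs = (torch.randn(1, 3, 224, 224),)
edge_model = ai_edge_torch.convert(quantized_model, sample_inputs)
edge_model.export("mobilenet_v2_qat.tflite")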