I’m trying to perform quantization-aware training (QAT) on MobileNetV2 with the goal of converting it to TFLite. However, after training the model, I run a benchmarking script to compare it against the original model and see that performance degrades significantly.
Here are the important code snippets:
import torch
from torchvision import models
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import XNNPACKQuantizer, get_symmetric_quantization_config
# Load the pretrained float model
model = models.mobilenet_v2(weights='DEFAULT')
# Export using a real batch as the example input
example_inputs = (next(iter(dataloader))[0].to(device),)
model = torch.export.export_for_training(model, example_inputs).module()
# Annotate the exported graph for QAT and insert fake-quantize modules
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_qat=True))
model = prepare_qat_pt2e(model, quantizer)
train_model(model)
I only included what I thought was relevant, since I didn’t want to add confusion with all of my helper functions.
def train_model(model):
    for phase in ['train', 'val']:
        is_train = phase == 'train'
        if is_train:
            torch.ao.quantization.move_exported_model_to_train(model)
        else:
            # Switch to evaluation mode to perform inference
            torch.ao.quantization.move_exported_model_to_eval(model)
        data_loader = train_loader if is_train else val_loader
        running_loss = 0.0
        total_samples = 0.0
        predictions, ground_truths, probabilities = [], [], []
        with tqdm(total=len(data_loader), desc=f"{phase.capitalize()} Epoch {epoch + 1}/{epochs}") as pbar:
            for inputs, labels in data_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                # Zero gradients only during training
                if is_train:
                    optimizer.zero_grad()
                # Enable gradients only in training phase
                with torch.set_grad_enabled(is_train):
                    model = model.to(device)
                    model_logits = model(inputs)
                    soft_loss = compute_distillation_loss(model_logits)
                    label_loss, probs, preds = compute_loss_and_predictions(model_logits, labels, criterion)
                    # Compute weighted combination of the distillation and cross entropy losses
                    loss = soft_target_loss_weight * soft_loss + ce_loss_weight * label_loss
                    # Backward pass and optimizer step in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()
                # Update progress bar with average loss so far
                pbar.set_postfix(loss=f"{running_loss / total_samples:.4f}")
                pbar.update(1)

quantized_model = convert_pt2e(model, fold_quantize=False)
Actual vs expected behavior:
I would expect the quantized model to be smaller and faster than the original model, but it is not.
Metric | Original | Export QAT |
---|---|---|
Model Size (MB) | 9.1899 | 11.1504 |
Inference Time (sec/sample) | 0.002896 | 0.011141 |
Throughput (samples/sec) | 345.29 | 89.76 |
Energy per Sample (Joules) | 0.3436 | 1.350853 |
Throughput per Watt (samples/sec/W) | 2.91 | 0.74 |
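
One way to narrow down where the extra megabytes come from is to break the model's parameters and buffers down by dtype. The helper below is only a minimal sketch and is not part of the benchmarking script: `bytes_by_dtype` and `format_megabytes` are hypothetical names, and the code assumes it is handed something like `quantized_model.state_dict()`, i.e. a mapping whose values expose `dtype`, `numel()` and `element_size()` the way `torch.Tensor` does.

```python
from collections import defaultdict


def bytes_by_dtype(state_dict):
    """Sum the storage size of every tensor in a state dict, grouped by dtype.

    Assumes each value behaves like a torch.Tensor, i.e. it exposes
    `dtype`, `numel()` and `element_size()`.
    """
    totals = defaultdict(int)
    for tensor in state_dict.values():
        totals[str(tensor.dtype)] += tensor.numel() * tensor.element_size()
    return dict(totals)


def format_megabytes(totals):
    """Render per-dtype totals as 'dtype: X.XXXX MB' lines, largest first."""
    ordered = sorted(totals.items(), key=lambda item: item[1], reverse=True)
    return [f"{dtype}: {size / 1e6:.4f} MB" for dtype, size in ordered]
```

It could be called as `print("\n".join(format_megabytes(bytes_by_dtype(quantized_model.state_dict()))))`. If most of the bytes turn out to be `torch.float32`, that would suggest the weights are still stored in floating point after `convert_pt2e(model, fold_quantize=False)`, which would be consistent with the size not shrinking.
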
This is even stranger because when I switch to FX graph mode QAT, I get the expected behavior. However, I need to use export (PT2E) quantization, since I want to use the ai-edge-torch API to convert my model to TFLite.
Metric | Original | FX Graph QAT |
---|---|---|
Model Size (MB) | 9.1899 | 2.3465 |
Inference Time (sec/sample) | 0.002896 | 0.000250 |
Throughput (samples/sec) | 345.29 | 4003.28 |
Energy per Sample (Joules) | 0.3436 | 0.0271 |
Throughput per Watt (samples/sec/W) | 2.91 | 36.85 |
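
To make the gap between the two flows easier to compare, the size and latency figures from both tables can be expressed relative to the original model. This is a small worked calculation, not a new measurement: `relative_change` is a hypothetical helper, and the FX latency is taken as 0.000250 sec/sample, which is what the reported throughput of 4003.28 samples/sec implies.

```python
# Values copied from the two benchmark tables above.
original = {"size_mb": 9.1899, "sec_per_sample": 0.002896}
export_qat = {"size_mb": 11.1504, "sec_per_sample": 0.011141}
# 0.000250 is inferred from the reported throughput (1 / 4003.28).
fx_qat = {"size_mb": 2.3465, "sec_per_sample": 0.000250}


def relative_change(baseline, candidate):
    """Return (size ratio, latency ratio) of candidate relative to baseline.

    Ratios above 1.0 mean the candidate is larger or slower than the baseline.
    """
    size_ratio = candidate["size_mb"] / baseline["size_mb"]
    latency_ratio = candidate["sec_per_sample"] / baseline["sec_per_sample"]
    return size_ratio, latency_ratio


for name, result in [("Export QAT", export_qat), ("FX graph QAT", fx_qat)]:
    size_ratio, latency_ratio = relative_change(original, result)
    print(f"{name}: {size_ratio:.2f}x size, {latency_ratio:.2f}x latency")
```

With these numbers, the export-based model comes out at roughly 1.21x the size and 3.85x the latency of the original, while the FX graph mode model is roughly 0.26x the size and 0.09x the latency.
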
Additionally, when I print the resulting QAT model I get the following:
GraphModule(
  (features): Module(
    (0): Module(
      (1): Module()
    )
    (1): Module(
      (conv): Module(
        (0): Module(
          (1): Module()
        )
        (2): Module()
      )
    )
    (2): Module(
      (conv): Module(
        (0): Module(
          (1): Module()
        )
        (1): Module(
          (1): Module()
        )
        (3): Module()
      )
    )
    (3): Module(
    ...
I would have expected it to look more like the QAT model produced by FX graph mode quantization, which leads me to believe that it is not training correctly. The FX graph mode output is shown below:
GraphModule(
  (features): Module(
    (0): Module(
      (0): QuantizedConv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), scale=0.22475136816501617, zero_point=113, padding=(1, 1))
      (2): ReLU6(inplace=True)
    )
    (1): Module(
      (conv): Module(
        (0): Module(
          (0): QuantizedConv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), scale=0.36381739377975464, zero_point=112, padding=(1, 1), groups=32)
          (2): ReLU6(inplace=True)
        )
        (1): QuantizedConv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), scale=0.5194709300994873, zero_point=139)
      )
    )
    ...
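
Since the export-based flow works on a flat `torch.fx` graph rather than swapping in module classes such as `QuantizedConv2d`, `print(model)` may not be the most informative view of the exported model. Counting the function calls in `model.graph` is one alternative. The following is a minimal sketch under stated assumptions: `count_call_targets` and `quantization_ops` are hypothetical helper names, and the only thing assumed about the model is that it exposes `graph.nodes` with `op` and `target` attributes, as `torch.fx.GraphModule` does.

```python
from collections import Counter


def count_call_targets(graph_module):
    """Count how often each function target is called in an FX-style graph.

    Assumes `graph_module.graph.nodes` yields nodes with `op` and `target`
    attributes, as torch.fx.GraphModule does.
    """
    return Counter(
        str(node.target)
        for node in graph_module.graph.nodes
        if node.op == "call_function"
    )


def quantization_ops(counts):
    """Keep only targets whose name mentions 'quantize' (also matches 'dequantize')."""
    return {target: n for target, n in counts.items() if "quantize" in target}
```

Calling `quantization_ops(count_call_targets(quantized_model))` on the converted model should, as far as I understand the PT2E flow, list quantize and dequantize calls if the conversion took effect. An empty result would point at the quantizer annotations or the conversion step rather than at the training loop.
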