Inference accuracy mismatch between original, quantized, dequantized model

I’m trying to do quantize-aware training on a VGG11 model for CIFAR10 dataset and finally want to do a dequantization on the quantized int8 weights to see the accuracy impact. The training should be successful and both the original fp32 model and the quantized model give reasonable inference accuracy, which is approximately 90%. However, when I was trying to do manual dequantization on the quantized model, the inference accuracy of the dequantized model dropped down to about 10%. Since I’m using the same int8 weight and the q_scale and zero_points during dequantization, I used to believe that the inference accuracy should be similar to the quantized model. Can anyone tell me what could be the potential reason?

Here is my code: (quantized_model and original_fp_model gives 90% accuracy and the other two gives 10% accuracy)
import torch
import torchvision
import torch.quantization as quantization
import torchvision.models as models
import torch.nn as nn
import torchvision.transforms as transforms

class QATVGG11(nn.Module):
def init(self, feature_output_size):
super(QATVGG11, self).init()
self.features = models.vgg11(pretrained=True).features
self.classifier = nn.Sequential(
nn.Linear(feature_output_size, 4096),
nn.ReLU(inplace=True),
#nn.Dropout(p=0.5),
nn.Linear(4096, 10))
self.quant = quantization.QuantStub()
self.dequant = quantization.DeQuantStub()

def forward(self, x):
    x = self.quant(x)
    x = self.features(x)
    x = torch.flatten(x, 1)
    x = self.classifier(x)
    x = self.dequant(x)
    return x

def load_quantized_model(model_save_path):
# Recreate the model architecture
model = QATVGG11(feature_output_size=25088) # Provide the correct feature_output_size used in training
model.to(‘cpu’) # Load to CPU for simplicity
model.qconfig = quantization.get_default_qat_qconfig(‘fbgemm’)
model = quantization.prepare_qat(model, inplace=False)
model.load_state_dict(torch.load(model_save_path))
model.eval() # Set the model to evaluation mode
quantized_model = quantization.convert(model, inplace=False)
return quantized_model

Path to the saved qat model

model_save_path = ‘vgg11_trained_fp32_nodropout.pth’
quantized_model = load_quantized_model(model_save_path)

def dequantize_modified_tensor(quantized_tensor, update_mode):
if not quantized_tensor.is_quantized:
raise ValueError(“The tensor is not quantized.”)

scales = quantized_tensor.q_per_channel_scales()
zero_points = quantized_tensor.q_per_channel_zero_points()
int_repr = quantized_tensor.int_repr()

scales = scales.view(-1, 1, 1, 1)  # Shape for broadcasting
zero_points = zero_points.view(-1, 1, 1, 1)  # Shape for broadcasting

float_tensor_orig = (int_repr - zero_points) * scales
if update_mode == "origin":
    float_tensor = (int_repr - zero_points) * scales
elif update_mode == "dequantize":
    float_tensor = quantized_tensor.dequantize()

return float_tensor

def replace_weights_with_dequantized(model, update_mode):
state_dict = model.state_dict()
model_fp32 = QATVGG11(feature_output_size=25088)
state_dict_fp32 = model_fp32.state_dict()
for name, param in state_dict.items():
if ‘weight’ in name and isinstance(param, torch.Tensor):
quantized_tensor = param
# Dequantize the tensor
float_tensor = dequantize_modified_tensor(quantized_tensor, update_mode)
# Replace the model’s weight with the dequantized tensor
state_dict_fp32[name].copy_(float_tensor)
model_fp32.load_state_dict(state_dict_fp32)
return model_fp32

new_model_origin = replace_weights_with_dequantized(quantized_model, ‘origin’)
new_model_origin_deq = replace_weights_with_dequantized(quantized_model, ‘dequantize’)

def evaluate_model(model, data_loader, device):
model.to(device)
model.eval()

state_dict = model.state_dict()
for name, param in state_dict.items():
correct = 0
total = 0
for data, target in test_loader:
    data, target = data.to(device), target.to(device)
    output = model(data)
    _, predicted = torch.max(output, 1)
    total += target.size(0)
    correct += (predicted == target).sum().item()
print(f'Accuracy of the model on the test set: {100 * correct / total:.2f}%')

original_fp_model = QATVGG11(feature_output_size=25088)
original_fp_model.qconfig = quantization.get_default_qat_qconfig(‘fbgemm’)
original_fp_model = quantization.prepare_qat(original_fp_model, inplace=False)
original_fp_model.load_state_dict(torch.load(model_save_path))

print(“Original model:”)
evaluate_model(quantized_model, test_loader, ‘cpu’)
print(“Original fp model:”)
evaluate_model(original_fp_model, test_loader, ‘cpu’)
print(“Original model after dequantization:”)
evaluate_model(new_model_origin, test_loader, ‘cpu’)
print(“Original model after pytorch dequantization:”)
evaluate_model(new_model_origin_deq, test_loader, ‘cpu’)