How to load quantized model for inference

Hello,
I have a ResNet classification model from torchvision that I fine-tuned on my custom dataset and saved (final_model.pth). Now I want to quantize this model. Quantization works without a problem, and if I run inference right away it is fine, but if I load my quantized model and then run inference, I get an error.

###### Quantization-Aware Training ######
import torch

# load the fine-tuned fp32 model and prepare it for QAT
model = torch.load("final_model.pth", map_location='cpu')
model.train()
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
#model_fp32_fused = torch.quantization.fuse_modules(model, [['conv1', 'bn1', 'relu']])
model_fp32_prepared = torch.quantization.prepare_qat(model)

# training settings
train_annotations_dir = val_annotations_dir = "/content/annotations.json"
train_img_dir = val_img_dir = "/content/images"
epoch = 2
batch_size = 32
image_size = 640
model_path = model_fp32_prepared
gpu_num = 0
learning_rate = 0.01
opt = "adam"
momentum = 0.9
patience = 5
output = "/content/output2/"

# fine-tune with fake-quant observers enabled, then convert to int8 on CPU
m = run_classification_train(train_annotations_dir, train_img_dir, val_annotations_dir, val_img_dir, epoch, batch_size, image_size, model_path, gpu_num, learning_rate, momentum, opt, patience)
m.to("cpu")
model_int8 = torch.quantization.convert(m)
final_model_path = model_int8
torch.save(model_int8.state_dict(), "state_qmodel.pth")
torch.save(model_int8, "fm.pth")

# running inference right away on the converted model works fine
inference(val_annotations_dir, val_img_dir, image_size, batch_size, final_model_path, output)
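Side note: in eager-mode quantization the float model normally needs a QuantStub/DeQuantStub pair around its forward so the converted int8 model can accept float inputs (torch.quantization.QuantWrapper provides the same thing). A minimal sketch of such a wrapper, which is an assumption on my part and not part of the original code:

import torch

# hypothetical wrapper (not in the original code): quantizes inputs on the
# way into the int8 model and dequantizes outputs on the way out
class ResNetQuantWrapper(torch.nn.Module):
    def __init__(self, fp32_model):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.model = fp32_model
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.model(self.quant(x)))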

If I load state_qmodel.pth into the fp32 model, I get this error:

Error(s) in loading state_dict for ResNet:
	Unexpected key(s) in state_dict: "conv1.bias", "conv1.scale", "conv1.zero_point", "layer1.0.conv1.bias", "layer1.0.conv1.scale", "layer1.0.conv1.zero_point", "layer1.0.conv2.bias", "layer1.0.conv2.scale", "layer1.0.conv2.zero_point", "layer1.0.conv3.bias", "layer1.0.conv3.scale", "layer1.0.conv3.zero_point", "layer1.0.downsample.0.bias", "layer1.0.downsample.0.scale", "layer1.0.downsample.0.zero_point", ... (the same bias/scale/zero_point triplet repeats for every conv in layer1.1 through layer2.2; output truncated)

import torch

# recreate the fp32 model, then try to load the quantized state dict
model = torch.load("/content/final_model.pth")
# this line raises the error above: the quantized keys don't match the fp32 modules
model.load_state_dict(torch.load('/content/state_qmodel.pth'))
inference(val_annotations_dir, val_img_dir, image_size, batch_size, model, output)

Hi @m.safari, when you run the quantization APIs they change the state dict, because quantized layers can have different fields than their floating-point counterparts. Therefore, when you load a quantized checkpoint, the recommendation is to create the fp32 architecture, run the quantization APIs (on random weights), and then load the quantized state dict. In your example, it would be something like:

# create the fp32 model
model = torch.load("/content/final_model.pth")

# quantize it without calibration (the weights will not be final)
model.train()
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
#model_fp32_fused = torch.quantization.fuse_modules(model, [['conv1', 'bn1', 'relu']])
model_fp32_prepared = torch.quantization.prepare_qat(model)
model_int8 = torch.quantization.convert(model_fp32_prepared)

# now load the real quantized state dict
model_int8.load_state_dict(torch.load('/content/state_qmodel.pth'))
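After load_state_dict succeeds, put the model in eval mode before inference. As an alternative that avoids rebuilding and re-preparing the architecture at every load, here is a sketch (assuming the model is scriptable; the file name qmodel_scripted.pth is arbitrary) that serializes the converted model with TorchScript, so structure and weights live in one file:

import torch

# after convert(): script and save the int8 model in a single file
scripted = torch.jit.script(model_int8)
torch.jit.save(scripted, "qmodel_scripted.pth")

# later: load without recreating the fp32 model or re-running prepare/convert
loaded = torch.jit.load("qmodel_scripted.pth")
loaded.eval()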