Hello,
I have a ResNet classifier from torchvision that I fine-tuned on my custom dataset and saved (final_model.pth).
Now I want to quantize this model. The quantization itself runs without a problem, and if I run inference right away it works, but if I load the saved quantized model and then run inference, I get an error.
###### Quantization Aware Training ######
import torch

# load the fine-tuned FP32 model and switch to train mode for QAT
model = torch.load("final_model.pth", map_location='cpu')
model.train()

# attach a QAT qconfig for the fbgemm (x86) backend
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
#model_fp32_fused = torch.quantization.fuse_modules(model, [['conv1', 'bn1', 'relu']])

# insert fake-quantization observers
model_fp32_prepared = torch.quantization.prepare_qat(model)
# training configuration for my run_classification_train helper
train_annotations_dir = val_annotations_dir = "/content/annotations.json"
train_img_dir = val_img_dir = "/content/images"
epoch = 2
batch_size = 32
image_size = 640
model_path = model_fp32_prepared
gpu_num = 0
learning_rate = 0.01
opt = "adam"
momentum = 0.9
patience = 5
output = "/content/output2/"

# fine-tune with the fake-quant observers in place
m = run_classification_train(train_annotations_dir, train_img_dir, val_annotations_dir, val_img_dir,
                             epoch, batch_size, image_size, model_path, gpu_num,
                             learning_rate, momentum, opt, patience)
m.to("cpu")
model_int8 = torch.quantization.convert(m)
final_model_path = m
torch.save(model_int8.state_dict(), "state_qmodel.pth")
torch.save(model_int8, "fm.pth")
inference(val_annotations_dir,val_img_dir,image_size,batch_size,final_model_path,output)
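As a side note, since fm.pth pickles the entire module, I assume it could be reloaded in one step (untested on my side; it requires the model's class definitions to be importable at load time):

import torch
# torch.load unpickles the whole quantized module, so its structure already
# matches and no load_state_dict call is needed
model_int8 = torch.load("fm.pth", map_location="cpu")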
But if I load state_qmodel.pth, I get this error:
Error(s) in loading state_dict for ResNet:
Unexpected key(s) in state_dict: "conv1.bias", "conv1.scale", "conv1.zero_point", "layer1.0.conv1.bias", "layer1.0.conv1.scale", "layer1.0.conv1.zero_point", "layer1.0.conv2.bias", "layer1.0.conv2.scale", "layer1.0.conv2.zero_point", "layer1.0.conv3.bias", "layer1.0.conv3.scale", "layer1.0.conv3.zero_point", "layer1.0.downsample.0.bias", "layer1.0.downsample.0.scale", "layer1.0.downsample.0.zero_point", "layer1.1.conv1.bias", "layer1.1.conv1.scale", "layer1.1.conv1.zero_point", "layer1.1.conv2.bias", "layer1.1.conv2.scale", "layer1.1.conv2.zero_point", "layer1.1.conv3.bias", "layer1.1.conv3.scale", "layer1.1.conv3.zero_point", "layer1.2.conv1.bias", "layer1.2.conv1.scale", "layer1.2.conv1.zero_point", "layer1.2.conv2.bias", "layer1.2.conv2.scale", "layer1.2.conv2.zero_point", "layer1.2.conv3.bias", "layer1.2.conv3.scale", "layer1.2.conv3.zero_point", "layer2.0.conv1.bias", "layer2.0.conv1.scale", "layer2.0.conv1.zero_point", "layer2.0.conv2.bias", "layer2.0.conv2.scale", "layer2.0.conv2.zero_point", "layer2.0.conv3.bias", "layer2.0.conv3.scale", "layer2.0.conv3.zero_point", "layer2.0.downsample.0.bias", "layer2.0.downsample.0.scale", "layer2.0.downsample.0.zero_point", "layer2.1.conv1.bias", "layer2.1.conv1.scale", "layer2.1.conv1.zero_point", "layer2.1.conv2.bias", "layer2.1.conv2.scale", "layer2.1.conv2.zero_point", "layer2.1.conv3.bias", "layer2.1.conv3.scale", "layer2.1.conv3.zero_point", "layer2.2.conv1.bias", "layer2.2.conv1.scale", "layer2.2.conv1.zero_point" …
This is the loading code that produces it:

import torch

# reload the original FP32 model and try to load the quantized state_dict into it
model = torch.load("/content/final_model.pth")
model.load_state_dict(torch.load('/content/state_qmodel.pth'))  # <-- this line raises the error above
final_model_path = model
inference(val_annotations_dir, val_img_dir, image_size, batch_size, final_model_path, output)
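From the unexpected scale/zero_point keys, my understanding is that the state_dict can only be loaded into a model whose modules have already been swapped for their quantized counterparts, so I would have to repeat the prepare/convert steps before loading. A minimal sketch of what I mean, reusing the same calls as the training code above (is this the intended workflow?):

import torch

# rebuild the quantized module structure first; the plain FP32 ResNet has no
# scale/zero_point attributes, which is why those keys are "unexpected"
model = torch.load("/content/final_model.pth", map_location='cpu')
model.train()
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model_prepared = torch.quantization.prepare_qat(model)
model_prepared.eval()
model_int8 = torch.quantization.convert(model_prepared)

# now the module structure matches the saved state_dict
model_int8.load_state_dict(torch.load('/content/state_qmodel.pth'))
inference(val_annotations_dir, val_img_dir, image_size, batch_size, model_int8, output)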