import torch
import copy
from torch.quantization import get_default_qconfig
qmodel = copy.deepcopy(model)
from torch.quantization.quantize_fx import prepare_fx, convert_fx
qmodel.eval()
qconfig = get_default_qconfig("qnnpack")
qconfig_dict = {"": qconfig}
def calibrate(model, data_loader):
model.eval()
with torch.no_grad():
for image, target in data_loader:
model(image)
prepared_model = prepare_fx(qmodel, qconfig_dict) # fuse modules and insert observers
calibrate(prepared_model, val_loader) # run calibration on sample data
quantized_model = convert_fx(prepared_model, is_reference=True)
the convert_fx function returns the following error:
AssertionError: Floating point module class <class ‘torch.nn.quantized._reference.modules.conv.Conv2d’> does not have a corresponding quantized module class