How to quantize Mobilenet_V2 after training

I trained MobileNetV2 with a 16-class output and saved the model. After I load the model and quantize it, I get an error when predicting. Can you help me fix this bug? Thanks a lot.

Code to load the model after training:

import torch
import torch.nn as nn
from torchvision import models

def load_model(model_path, num_classes=16, device=torch.device('cpu')):
    model = models.mobilenet_v2()
    # replace the classifier head so it matches the 16-class fine-tuned checkpoint
    n_input_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(n_input_features, num_classes)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval().to(device)
    return model

Code to quantize the model:

model_fp32 = load_model("mobilenetv2_16_classes.pt").to("cpu")
model_fp32.eval()
print(model_fp32)
backend = "fbgemm"
model_fp32.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend
model_static_quantized = torch.quantization.prepare(model_fp32, inplace=False)
model_static_quantized = torch.quantization.convert(model_static_quantized, inplace=False)

Quantization seems to succeed (the model size shrinks):

import os

def print_model_size(mdl):
    # save the state dict to disk, report its size, then clean up
    torch.save(mdl.state_dict(), "tmp.pt")
    print("%.2f MB" % (os.path.getsize("tmp.pt") / 1e6))
    os.remove("tmp.pt")

print_model_size(model_fp32) 
print_model_size(model_static_quantized)
9.6MB
2.7MB

Then I run prediction and get this error:

model_static_quantized(norm_img)
return ops.quantized.conv2d(
NotImplementedError: Could not run 'quantized::conv2d.new' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'quantized::conv2d.new' is only available for these backends: [QuantizedCPU, BackendSelect, Python, Named, Conjugate, Negative, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradLazy, AutogradXPU, AutogradMLC, Tracer, UNKNOWN_TENSOR_TYPE_ID, Autocast, Batched, VmapMode].

The error message means that the input needs to be quantized.
The easiest way to achieve this might be to wrap the model into a larger model and add torch.quantization.QuantStub/DeQuantStub layers before calling into the model.

The quantization tutorial has these steps, though they are easy to miss.
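In code, the wrapping could look roughly like this (just a sketch, not the exact tutorial code; QuantizedMobileNet is a made-up name and load_model is the helper from your post):

import torch
import torch.nn as nn

class QuantizedMobileNet(nn.Module):
    # hypothetical wrapper: quantize the input, run the model, dequantize the output
    def __init__(self, model):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.model = model
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)       # fp32 tensor -> quantized tensor
        x = self.model(x)
        return self.dequant(x)  # quantized tensor -> fp32 tensor

# wrapped = QuantizedMobileNet(load_model("mobilenetv2_16_classes.pt"))
# wrapped.eval()
# then run the same prepare / calibrate / convert steps on `wrapped` instead of on the bare model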

Best regards

Thomas

Thanks @tom. But I don’t know how to quantize my MobileNetV2 model that was fine-tuned with a 16-class output. Can you help me?

Something like

import torch
import torchvision
model = torchvision.models.mobilenet_v2()
# load your fine-tuned weights here into model (replace the classifier head with the 16-class Linear first, as in load_model above)
qmodel = torch.quantization.QuantWrapper(model)

qmodel.eval()

backend = "fbgemm"
qmodel.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend
model_static_quantized = torch.quantization.prepare(qmodel, inplace=False)

# feed real (calibration) inputs here so the observers record activation ranges; torch.randn just demonstrates the call
model_static_quantized(torch.randn(1, 3, 224, 224))

model_static_quantized = torch.quantization.convert(model_static_quantized, inplace=False)

model_static_quantized  # inspect the converted (quantized) modules

seems to work for me.
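After convert, inference should go through the wrapped model end to end, e.g. (a sketch, assuming norm_img is your preprocessed float tensor of shape (1, 3, 224, 224)):

with torch.no_grad():
    logits = model_static_quantized(norm_img)  # the QuantStub quantizes the fp32 input internally
    pred = logits.argmax(dim=1)                # index of the predicted class (0..15)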

Best regards

Thomas