I used `fbgemm` as the qconfig, and I checked that my CPU (an Intel Xeon Silver 4114) supports AVX2 instructions. But my quantized model takes about 3 times longer to run inference than the original fp32 model.
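
For reference, this is roughly how I check and select the engine (a minimal sketch; the printed list of supported engines depends on the PyTorch build):

```
import torch

# Check which quantized backends this build supports, then pick fbgemm
print(torch.backends.quantized.supported_engines)  # e.g. ['none', 'fbgemm']
torch.backends.quantized.engine = 'fbgemm'
```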

My original model is as follows:

```
# Original model (Unet itself is defined elsewhere in my code)
import torch
import torch.nn as nn


class NormUnet(nn.Module):
    def __init__(self, in_chans, out_chans, chans, num_pools):
        super().__init__()
        self.unet = Unet(
            in_chans=in_chans,
            out_chans=out_chans,
            chans=chans,
            num_pool_layers=num_pools,
        )

    def norm(self, x):
        # Normalize each sample to zero mean / unit std over the whole image
        b, h, w = x.shape
        x = x.view(b, h * w)
        mean = x.mean(dim=1).view(b, 1, 1)
        std = x.std(dim=1).view(b, 1, 1)
        x = x.view(b, h, w)
        return (x - mean) / std, mean, std

    def unnorm(self, x, mean, std):
        return x * std + mean

    def forward(self, x):
        x, mean, std = self.norm(x)
        x = x.unsqueeze(1)  # add a channel dimension for the conv layers
        x = self.unet(x)
        x = x.squeeze(1)
        x = self.unnorm(x, mean, std)
        return x
```

As `std` in the `norm` method doesn't support quantization, I quantized only the unet and then put it back into the original model.

```
# Load the pre-trained model
model = NormUnet(in_chans=1, out_chans=1, chans=128, num_pools=4)
model = torch.nn.parallel.DataParallel(model)
checkpoint = torch.load(model_path)
model.load_state_dict(checkpoint['model'])
model = model.module
model.to('cpu')
model.eval()  # static quantization requires eval mode

# Extract only the unet part
unet_before_quantize = model.unet


# Wrapper that adds quant/dequant stubs around the fp32 unet
class QuantizedUNet(nn.Module):
    def __init__(self, model_fp32):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.model_fp32 = model_fp32

    def forward(self, x):
        x = self.quant(x)
        x = self.model_fp32(x)
        x = self.dequant(x)
        return x


quantized_model = QuantizedUNet(model_fp32=unet_before_quantize)
quantization_config = torch.quantization.get_default_qconfig('fbgemm')
quantized_model.qconfig = quantization_config
torch.quantization.prepare(quantized_model, inplace=True)
```

Then I calibrate and convert. After all this, the quantized model's output is what I expected, but inference takes about 3~4 times longer. My calibration step looks roughly like this (a sketch; `calibration_loader` stands in for my real loader of representative fp32 inputs):
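
```
# Rough calibration sketch: run representative inputs through the
# prepared model so the observers can record activation ranges
# (`calibration_loader` is a placeholder for my real data loader)
with torch.no_grad():
    for x in calibration_loader:
        quantized_model(x)
```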

```
# After calibration, convert the model to int8
quantized_model = torch.quantization.convert(quantized_model, inplace=True)
quantized_model.eval()

# Replace only the unet block in the original model
model.unet = quantized_model
```
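
For completeness, this is roughly how I measured inference time (a sketch; `sample` is a hypothetical input tensor of the shape my model expects):

```
import time

import torch

# Hypothetical representative input: (batch, height, width)
sample = torch.randn(1, 320, 320)


def benchmark(m, x, iters=20):
    # Average wall-clock seconds per forward pass, after one warm-up run
    with torch.no_grad():
        m(x)  # warm-up
        start = time.perf_counter()
        for _ in range(iters):
            m(x)
    return (time.perf_counter() - start) / iters


print('avg seconds per inference:', benchmark(model, sample))
```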

Where did I make mistakes?

I tried `torch.set_num_threads(1)`, and I also exported the model to TorchScript and loaded it back. Neither got it faster than the original fp32 model.
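
The TorchScript round-trip looked roughly like this (a sketch; the file name is arbitrary and `sample` is the same hypothetical input as in the benchmark above):

```
import torch

# Trace with a representative input, save, and load back
sample = torch.randn(1, 320, 320)
scripted = torch.jit.trace(model, sample)
torch.jit.save(scripted, 'quantized_norm_unet.pt')  # hypothetical path
model = torch.jit.load('quantized_norm_unet.pt')
```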

Any help will be really appreciated.

Thanks!