Hi! I have a quantized object detection model using a static quantization strategy. The model got quantized and works fine. The problem I am facing is that the very first two forward passes take a much longer time than the passes that follow. By “longer” I mean a couple of seconds, whereas after the first two passes the model execution takes about ~5ms on average. I have followed the official documentation in implementing the training script, but it seems to me there is a bug somewhere that causes the problem. Could you please help me out with this? Below I have attached the model and the training script. Thanks a lot!
class BasicConv2d(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU building block.

    The submodules are deliberately named ``conv``/``bn``/``relu`` so the
    triple can be fused for static quantization via
    ``torch.quantization.fuse_modules(m, ['conv', 'bn', 'relu'])``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        # bias=False: the following BatchNorm has its own affine shift,
        # which would cancel a conv bias anyway.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=1e-5)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply conv, batch norm, and ReLU in sequence."""
        return self.relu(self.bn(self.conv(x)))
# study the model in more depth
class FaceBoxesMobNet(nn.Module):
    """FaceBoxes-style multi-scale detection head on a MobileNetV2 backbone.

    The forward pass returns a tuple ``(loc, conf, None)``:
    - ``loc``:  (batch, num_anchors, 4) box regression offsets
    - ``conf``: classification scores — sigmoid probabilities in 'test'
      phase, raw logits reshaped to (batch, num_anchors, num_classes)
      otherwise.

    QuantStub/DeQuantStub mark the int8 region for PyTorch static
    (post-training) quantization.
    """

    def __init__(self, phase, num_classes=1, quantization=False):
        """
        Args:
            phase: 'train' or 'test'; controls weight init and output form.
            num_classes: number of classes per anchor (1 = binary/face).
            quantization: flag stored for callers; the quant/dequant stubs
                are created unconditionally (they are no-ops until prepared).
        """
        super(FaceBoxesMobNet, self).__init__()
        self.phase = phase
        self.num_classes = num_classes
        self.backbone = MobileNetV2()
        # 1x1 grouped conv projecting the 1280-ch backbone output down to
        # 128 channels (groups=128 keeps it cheap).
        self.avg = nn.Conv2d(1280, 128, kernel_size=1, bias=False, groups=128)
        self.conv3_1 = BasicConv2d(128, 128, kernel_size=1, stride=1, padding=0)
        self.conv3_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.conv4_1 = BasicConv2d(256, 128, kernel_size=1, stride=1, padding=0)
        self.conv4_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.get_multibox(self.num_classes)
        self.quantization = quantization
        # Entry/exit points of the quantized (int8) region.
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.prob = nn.Sigmoid()
        if self.phase == 'train':
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    if m.bias is not None:
                        nn.init.xavier_normal_(m.weight.data)
                        m.bias.data.fill_(0.02)
                    else:
                        m.weight.data.normal_(0, 0.01)
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()

    def get_multibox(self, num_classes):
        """Create per-scale localization/classification heads (3 anchors per cell)."""
        self.loc0 = nn.Conv2d(128, 3 * 4, kernel_size=3, padding=1)
        self.conf0 = nn.Conv2d(128, 3 * num_classes, kernel_size=3, padding=1)
        self.loc1 = nn.Conv2d(256, 3 * 4, kernel_size=3, padding=1)
        self.conf1 = nn.Conv2d(256, 3 * num_classes, kernel_size=3, padding=1)
        self.loc2 = nn.Conv2d(256, 3 * 4, kernel_size=3, padding=1)
        self.conf2 = nn.Conv2d(256, 3 * num_classes, kernel_size=3, padding=1)

    def fuse_model(self):
        """Fuse conv+bn(+relu) sequences in place so each fuses into a single
        quantized op. Must be called in eval mode before ``prepare``."""
        for m in self.modules():
            if type(m) == ConvBNReLU:
                torch.quantization.fuse_modules(m, ['0', '1', '2'], inplace=True)
            if type(m) == BasicConv2d:
                torch.quantization.fuse_modules(m, ['conv', 'bn', 'relu'], inplace=True)
            if type(m) == InvertedResidual:
                for idx in range(len(m.conv)):
                    if type(m.conv[idx]) == nn.Conv2d:
                        # Fuse the projection conv with the bn that follows it.
                        torch.quantization.fuse_modules(
                            m.conv, [str(idx), str(idx + 1)], inplace=True)

    def forward(self, x):
        x = self.quant(x)  # enter the quantized (int8) domain
        x = self.backbone(x)
        x = self.avg(x)
        # Scale 0 heads (stride of backbone output).
        loc0 = self.loc0(x).permute(0, 2, 3, 1).contiguous()
        loc0 = loc0.view(loc0.size(0), -1)
        conf0 = self.conf0(x).permute(0, 2, 3, 1).contiguous()
        conf0 = conf0.view(conf0.size(0), -1)
        # Scale 1 (extra stride-2 conv).
        x = self.conv3_1(x)
        x = self.conv3_2(x)
        loc1 = self.loc1(x).permute(0, 2, 3, 1).contiguous()
        loc1 = loc1.view(loc1.size(0), -1)
        conf1 = self.conf1(x).permute(0, 2, 3, 1).contiguous()
        conf1 = conf1.view(conf1.size(0), -1)
        # Scale 2 (one more stride-2 conv).
        x = self.conv4_1(x)
        x = self.conv4_2(x)
        loc2 = self.loc2(x).permute(0, 2, 3, 1).contiguous()
        # BUG FIX: previously reshaped with loc1.size(0); use loc2's own
        # batch dimension so the view is always consistent.
        loc2 = loc2.view(loc2.size(0), -1)
        conf2 = self.conf2(x).permute(0, 2, 3, 1).contiguous()
        conf2 = conf2.view(conf2.size(0), -1)
        # Leave the quantized domain before concatenation/post-processing.
        loc0 = self.dequant(loc0)
        conf0 = self.dequant(conf0)
        loc1 = self.dequant(loc1)
        conf1 = self.dequant(conf1)
        loc2 = self.dequant(loc2)
        conf2 = self.dequant(conf2)
        loc = torch.cat([loc0, loc1, loc2], dim=1)
        conf = torch.cat([conf0, conf1, conf2], dim=1)
        if self.phase == "test":
            output = (loc.view(loc.size(0), -1, 4),
                      self.prob(conf.view(-1, self.num_classes)),
                      None)
        else:
            output = (loc.view(loc.size(0), -1, 4),
                      conf.view(conf.size(0), -1, self.num_classes),
                      None)
        return output
def run_static_quantization(mode='per_tensor'):
    """Post-training static quantization: load a float checkpoint, fuse,
    calibrate on the validation set, convert to int8, and save a scripted
    model.

    Args:
        mode: 'per_tensor' for the default MinMax per-tensor qconfig, or
            'per_channel' for the fbgemm qconfig (per-channel weights +
            histogram activation observers).
    """
    dataset_path = 'path/to/data/'
    device = 'cpu'
    model = FaceBoxesMobNet('test')
    model = model.to(device)
    model.eval()
    model.load_state_dict(torch.load(os.path.join(args.save_folder,
                                                  f'{experiment_name()}/{args.quantized_model}'),
                                     map_location=device))
    print("Loaded the model successfully. Start measuring the speed")
    s = run_speed_test(net=model, device=device)
    print(f"Speed of model before quantization: {s} ms/frame")
    print("Size of model before quantization")
    print_size_of_model(model)
    # Fuse Conv, bn and relu (must happen before observers are inserted).
    model.fuse_model()
    # Specify quantization configuration.
    # BUG FIX: the condition was inverted — default_qconfig is the simple
    # per-TENSOR MinMax config, while get_default_qconfig('fbgemm') enables
    # per-CHANNEL weight quantization. Select accordingly.
    model.qconfig = (torch.quantization.default_qconfig if mode == 'per_tensor'
                     else torch.quantization.get_default_qconfig('fbgemm'))
    print(model.qconfig)
    torch.quantization.prepare(model, inplace=True)
    # Calibrate observers with the val set.
    run_evaluation(net=model, device=device, dataset_path=dataset_path)
    print(f'Post Training {mode} Quantization: Calibration done')
    # Convert to a quantized model.
    torch.quantization.convert(model, inplace=True)
    print('Post Training Quantization: Convert done')
    print(f"Size of model after {mode} quantization")
    print_size_of_model(model)
    print("Performance of model after quantization")
    run_evaluation(net=model, device='cpu', dataset_path=dataset_path)
    s = run_speed_test(net=model, device=device)
    print(f"Speed of model after {mode} quantization: {s} ms/frame")
    # NOTE(review): the model is saved with torch.jit.script. TorchScript's
    # profiling executor specializes/optimizes the graph during the first
    # couple of invocations, which is the likely cause of the reported
    # "first two forward passes take seconds" symptom — presumably not a bug
    # in this script. Warm the model up with 2-3 dummy forward passes after
    # loading (or tune torch._C._jit_set_profiling_mode) — TODO confirm.
    torch.jit.save(torch.jit.script(model),
                   os.path.join(args.save_folder, f'{experiment_name()}/static_final{mode}.pth'))