First forward pass in a quantized model takes a long time

Hi! I have quantized an object detection model using the static quantization strategy. The model was quantized successfully and works fine. The problem I am facing is that the very first two forward passes take much longer than the passes that follow. By “longer” I mean a couple of seconds, whereas after the first two passes the model executes in about 5 ms on average. I followed the official documentation when implementing the training script, but it seems to me there is a bug somewhere that causes the problem. Could you please help me out with this? Below I have attached the model and the training script. Thanks a lot! :slight_smile:

class BasicConv2d(nn.Module):
    """Bias-free 2-D convolution followed by batch normalisation and ReLU.

    Any extra keyword arguments (``kernel_size``, ``stride``, ``padding``, ...)
    are forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # Attribute names and creation order are kept (conv -> bn -> relu) so
        # that state-dict keys stay the same.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=1e-5)
        # ReLU is a module attribute rather than a call to
        # nn.functional.relu(x, inplace=True); FaceBoxesMobNet.fuse_model
        # refers to it by the name 'relu'.
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        return self.relu(self.bn(self.conv(x)))

# study the model in more depth
class FaceBoxesMobNet(nn.Module):
    """Detector made of a ``MobileNetV2`` backbone and three multibox heads.

    The backbone output is reduced to 128 channels by ``avg``; two further
    stride-2 stages (``conv3_*`` and ``conv4_*``) produce 256-channel maps.
    Each of the three feature maps feeds one localisation head and one
    confidence head.  ``quant`` / ``dequant`` stubs wrap the convolutional
    part so the model can be statically quantized.

    Args:
        phase: ``'train'`` triggers custom weight initialisation in
            ``__init__``; ``'test'`` makes ``forward`` apply a sigmoid to the
            confidences.
        num_classes: number of confidence values predicted per box.
        quantization: stored on the instance as ``self.quantization``; not
            otherwise used inside this class.
    """

    def __init__(self, phase, num_classes=1, quantization=False):
        super(FaceBoxesMobNet, self).__init__()
        self.phase = phase
        self.num_classes = num_classes

        # Alternative that was tried: MobileNetV3() as backbone together with
        # avg = nn.Conv2d(96, 128, kernel_size=1, bias=False).
        self.backbone = MobileNetV2()
        # Open question from the author: why not use the backbone feature
        # maps directly instead of this grouped 1x1 reduction?
        self.avg = nn.Conv2d(1280, 128, kernel_size=1, bias=False, groups=128)

        self.conv3_1 = BasicConv2d(128, 128, kernel_size=1, stride=1, padding=0)
        self.conv3_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1)

        self.conv4_1 = BasicConv2d(256, 128, kernel_size=1, stride=1, padding=0)
        self.conv4_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1)

        self.get_multibox(self.num_classes)
        self.quantization = quantization
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

        # Created unconditionally (not only for phase == 'test').
        self.prob = nn.Sigmoid()

        if self.phase == 'train':
            # Note: self.modules() also walks the backbone, so its layers are
            # re-initialised here as well.
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    if m.bias is not None:
                        nn.init.xavier_normal_(m.weight.data)
                        m.bias.data.fill_(0.02)
                    else:
                        m.weight.data.normal_(0, 0.01)
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()

    def get_multibox(self, num_classes):
        """Create the localisation (``loc*``) and confidence (``conf*``) heads.

        Every head is a 3x3 convolution with padding 1.  Each ``loc`` head
        emits ``3 * 4`` channels and each ``conf`` head ``3 * num_classes``
        channels (presumably 3 anchors per cell -- confirm against the prior
        box configuration).  A variant with multipliers 2 / 1 / 1 for heads
        0 / 1 / 2 was tried previously.
        """
        self.loc0 = nn.Conv2d(128, 3 * 4, kernel_size=3, padding=1)
        self.conf0 = nn.Conv2d(128, 3 * num_classes, kernel_size=3, padding=1)
        self.loc1 = nn.Conv2d(256, 3 * 4, kernel_size=3, padding=1)
        self.conf1 = nn.Conv2d(256, 3 * num_classes, kernel_size=3, padding=1)
        self.loc2 = nn.Conv2d(256, 3 * 4, kernel_size=3, padding=1)
        self.conf2 = nn.Conv2d(256, 3 * num_classes, kernel_size=3, padding=1)

    def fuse_model(self):
        """Fuse conv/bn(/relu) groups in place, ahead of quantization.

        Types are compared exactly (not with ``isinstance``), as in the
        original code, so subclasses of the listed types are not touched.
        """
        for m in self.modules():
            if type(m) is ConvBNReLU:
                torch.quantization.fuse_modules(m, ['0', '1', '2'], inplace=True)
            if type(m) is BasicConv2d:
                torch.quantization.fuse_modules(m, ['conv', 'bn', 'relu'], inplace=True)
            if type(m) is InvertedResidual:
                # A plain Conv2d inside the block is fused with the module
                # that directly follows it (assumed to be its BatchNorm --
                # TODO confirm against the InvertedResidual definition).
                for idx in range(len(m.conv)):
                    if type(m.conv[idx]) is nn.Conv2d:
                        torch.quantization.fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)

    def _flatten_head(self, head_out: torch.Tensor) -> torch.Tensor:
        """Move channels last (N, C, H, W -> N, H, W, C) and flatten per sample."""
        head_out = head_out.permute(0, 2, 3, 1).contiguous()
        # Always reshape with the tensor's own batch size.
        return head_out.view(head_out.size(0), -1)

    def forward(self, x):
        """Run the detector.

        Returns:
            A 3-tuple ``(loc, conf, None)``.  ``loc`` has shape
            ``(batch, -1, 4)``.  When ``phase == "test"``, ``conf`` is the
            sigmoid of the confidences reshaped to ``(-1, num_classes)``;
            otherwise it is the raw confidences with shape
            ``(batch, -1, num_classes)``.
        """
        # Static quantization: the input is quantized here ...
        x = self.quant(x)

        x = self.backbone(x)
        x = self.avg(x)

        loc0 = self._flatten_head(self.loc0(x))
        conf0 = self._flatten_head(self.conf0(x))

        x = self.conv3_1(x)
        x = self.conv3_2(x)

        loc1 = self._flatten_head(self.loc1(x))
        conf1 = self._flatten_head(self.conf1(x))

        x = self.conv4_1(x)
        x = self.conv4_2(x)

        # The original reshaped loc2 with loc1.size(0); the helper uses the
        # tensor's own batch dimension instead.
        loc2 = self._flatten_head(self.loc2(x))
        conf2 = self._flatten_head(self.conf2(x))

        # ... and every head output is converted back to floating point here.
        loc0 = self.dequant(loc0)
        conf0 = self.dequant(conf0)
        loc1 = self.dequant(loc1)
        conf1 = self.dequant(conf1)
        loc2 = self.dequant(loc2)
        conf2 = self.dequant(conf2)

        loc = torch.cat([loc0, loc1, loc2], dim=1)
        conf = torch.cat([conf0, conf1, conf2], dim=1)

        if self.phase == "test":
            output = (loc.view(loc.size(0), -1, 4),
                      self.prob(conf.view(-1, self.num_classes)),
                      None)
        else:
            output = (loc.view(loc.size(0), -1, 4),
                      conf.view(conf.size(0), -1, self.num_classes),
                      None)
        return output
def run_static_quantization(mode='per_tensor'):
    """Apply post-training static quantization to a trained ``FaceBoxesMobNet``.

    Steps, in order: load the float checkpoint, measure its speed and size,
    fuse conv/bn/relu groups, attach a qconfig, insert observers
    (``prepare``), calibrate by running ``run_evaluation``, convert to a
    quantized model, evaluate and time it again, and finally save a
    TorchScript version under ``args.save_folder``.

    Args:
        mode: ``'per_channel'`` selects ``get_default_qconfig('fbgemm')``
            (per-channel weight quantization); any other value, including the
            default ``'per_tensor'``, selects ``default_qconfig`` (min/max
            range estimation with per-tensor weight quantization).

    NOTE(review): the first few forward passes of a TorchScript module are
    commonly much slower than later ones because the JIT profiles and
    optimises the graph during them; this is the likely cause of slow initial
    passes on the saved model.  Running a couple of warm-up passes before
    timing should hide it -- confirm on the target runtime.
    """
    dataset_path = 'path/to/data/'
    device = 'cpu'
    model = FaceBoxesMobNet('test')
    model = model.to(device)
    model.eval()

    # If the checkpoint keys carry a 'module.' prefix, strip it first, e.g.
    # with remove_prefix(state_dict, 'module.'), before load_state_dict.
    model.load_state_dict(torch.load(os.path.join(args.save_folder,
                                                f'{experiment_name()}/{args.quantized_model}'),
                                     map_location=device))

    print("Loaded the model successfully. Start measuring the speed")

    s = run_speed_test(net=model, device=device)
    print(f"Speed of model before quantization: {s} ms/frame")

    print("Size of model before quantization")
    print_size_of_model(model)
    # Fuse Conv, bn and relu
    model.fuse_model()

    # Specify quantization configuration.
    # The original condition was inverted: 'per_channel' picked the
    # per-tensor default_qconfig and 'per_tensor' picked the fbgemm config.
    if mode == 'per_channel':
        model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    else:
        # Simple min/max range estimation and per-tensor weight quantization.
        model.qconfig = torch.quantization.default_qconfig

    print(model.qconfig)
    torch.quantization.prepare(model, inplace=True)

    # Calibrate with the val set
    run_evaluation(net=model, device=device, dataset_path=dataset_path)
    print(f'Post Training {mode} Quantization: Calibration done')

    # Convert to quantized model
    torch.quantization.convert(model, inplace=True)
    print('Post Training Quantization: Convert done')

    print(f"Size of model after {mode} quantization")
    print_size_of_model(model)
    print("Performance of model after quantization")
    # Use the same device variable as everywhere else instead of a literal.
    run_evaluation(net=model, device=device, dataset_path=dataset_path)

    s = run_speed_test(net=model, device=device)
    print(f"Speed of model after {mode} quantization: {s} ms/frame")

    torch.jit.save(torch.jit.script(model),
                   os.path.join(args.save_folder, f'{experiment_name()}/static_final{mode}.pth'))