First forward pass in quantized model takes a long time

Hi! I have quantized an object detection model using the static quantization strategy. The model got quantized and works fine. The problem I am facing is that the very first two forward passes take way longer than the following passes. By “longer” I mean a couple of seconds, whereas after the first two passes the model execution takes about ~5 ms on average. I have followed the official documentation in implementing the training script, but it seems to me there is a bug somewhere that causes the problem. Could you please help me out with this? Below I have attached the model and the training script. Thanks a lot! :slight_smile:

class BasicConv2d(nn.Module):

   def __init__(self, in_channels, out_channels, **kwargs):
       super(BasicConv2d, self).__init__()
       self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
       self.bn = nn.BatchNorm2d(out_channels, eps=1e-5)
       self.relu = nn.ReLU()
   def forward(self, x):
       x = self.conv(x)
       x = self.bn(x)
       x = self.relu(x)
       return x 
       # return nn.functional.relu(x, inplace=True)

# study the model in more depth
class FaceBoxesMobNet(nn.Module):

   def __init__(self, phase, num_classes=1, quantization=False):
       super(FaceBoxesMobNet, self).__init__()
       self.phase = phase
       self.num_classes = num_classes

       self.backbone = MobileNetV2()
       # self.backbone = MobileNetV3()
       # why don't we just use feature maps from the backbone
       self.avg = nn.Conv2d(1280, 128, kernel_size=1, bias=False, groups=128)
       # self.avg = nn.Conv2d(96, 128, kernel_size=1, bias=False)

       self.conv3_1 = BasicConv2d(128, 128, kernel_size=1, stride=1, padding=0)
       self.conv3_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1)

       self.conv4_1 = BasicConv2d(256, 128, kernel_size=1, stride=1, padding=0)
       self.conv4_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1)

       self.get_multibox(self.num_classes)
       self.quantization = quantization
       self.quant = QuantStub()
       self.dequant = DeQuantStub()

       # if self.phase == 'test':
       self.prob = nn.Sigmoid()

       if self.phase == 'train':
           for m in self.modules():
               if isinstance(m, nn.Conv2d):
                   if m.bias is not None:
                       nn.init.xavier_normal_(m.weight.data)
                       m.bias.data.fill_(0.02)
                   else:
                       m.weight.data.normal_(0, 0.01)
               elif isinstance(m, nn.BatchNorm2d):
                   m.weight.data.fill_(1)
                   m.bias.data.zero_()

   def get_multibox(self, num_classes):

       self.loc0 = nn.Conv2d(128, 3 * 4, kernel_size=3, padding=1)
       self.conf0 = nn.Conv2d(128, 3 * num_classes, kernel_size=3, padding=1)
       self.loc1 = nn.Conv2d(256, 3 * 4, kernel_size=3, padding=1)
       self.conf1 = nn.Conv2d(256, 3 * num_classes, kernel_size=3, padding=1)
       self.loc2 = nn.Conv2d(256, 3 * 4, kernel_size=3, padding=1)
       self.conf2 = nn.Conv2d(256, 3 * num_classes, kernel_size=3, padding=1)

       # self.loc0 = nn.Conv2d(128, 2 * 4, kernel_size=3, padding=1)
       # self.conf0 = nn.Conv2d(128, 2 * num_classes, kernel_size=3, padding=1)
       # self.loc1 = nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)
       # self.conf1 = nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)
       # self.loc2 = nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)
       # self.conf2 = nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)

   def fuse_model(self):
       for m in self.modules():
           if type(m) == ConvBNReLU:
               torch.quantization.fuse_modules(m, ['0', '1', '2'], inplace=True)
           if type(m) == BasicConv2d:
               torch.quantization.fuse_modules(m, ['conv', 'bn', 'relu'], inplace=True)
           if type(m) == InvertedResidual:
               for idx in range(len(m.conv)):
                   if type(m.conv[idx]) == nn.Conv2d:
                       torch.quantization.fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)

   def forward(self, x):
        # here we are doing static quantization
        x = self.quant(x)    # <-- where the tensors are quantized

       x = self.backbone(x)
        # print('output shape of backbone: {}'.format(x.shape))
       x = self.avg(x)

       loc0 = self.loc0(x).permute(0, 2, 3, 1).contiguous()
       loc0 = loc0.view(loc0.size(0), -1)
       conf0 = self.conf0(x).permute(0, 2, 3, 1).contiguous()
       conf0 = conf0.view(conf0.size(0), -1)

       x = self.conv3_1(x)
       x = self.conv3_2(x)

       loc1 = self.loc1(x).permute(0, 2, 3, 1).contiguous()
       loc1 = loc1.view(loc1.size(0), -1)
       conf1 = self.conf1(x).permute(0, 2, 3, 1).contiguous()
       conf1 = conf1.view(conf1.size(0), -1)

       x = self.conv4_1(x)
       x = self.conv4_2(x)

       loc2 = self.loc2(x).permute(0, 2, 3, 1).contiguous()
        loc2 = loc2.view(loc2.size(0), -1)
       conf2 = self.conf2(x).permute(0, 2, 3, 1).contiguous()
       conf2 = conf2.view(conf2.size(0), -1)

       loc0 = self.dequant(loc0)
       conf0 = self.dequant(conf0)
       loc1 = self.dequant(loc1)
       conf1 = self.dequant(conf1)
       loc2 = self.dequant(loc2)
        conf2 = self.dequant(conf2)  # <-- here the tensors are converted back to floating point precision

       loc = torch.cat([loc0, loc1, loc2], dim=1)
       conf = torch.cat([conf0, conf1, conf2], dim=1)

       if self.phase == "test":
           output = (loc.view(loc.size(0), -1, 4),
                     self.prob(conf.view(-1, self.num_classes)),
                     None)
       else:
           output = (loc.view(loc.size(0), -1, 4),
                     conf.view(conf.size(0), -1, self.num_classes),
                     None)
       return output
def run_static_quantization(mode='per_tensor'):

    dataset_path = 'path/to/data/'
    device = 'cpu'
    model = FaceBoxesMobNet('test')
    model = model.to(device)
    model.eval()

    # model.load_state_dict(remove_prefix(torch.load(os.path.join(args.save_folder,
    #                                                             f'{experiment_name()}/{args.quantized_model}'),
    #                                                map_location=device), 'module.'))
    model.load_state_dict(torch.load(os.path.join(args.save_folder,
                                                  f'{experiment_name()}/{args.quantized_model}'),
                                     map_location=device))

    print("Loaded the model successfully. Start measuring the speed")

    s = run_speed_test(net=model, device=device)
    print(f"Speed of model before quantization: {s} ms/frame")

    print("Size of model before quantization")
    print_size_of_model(model)
    # Fuse Conv, bn and relu
    model.fuse_model()

    # Specify quantization configuration:
    # 'per_tensor' uses simple min/max range estimation and per-tensor quantization of weights,
    # otherwise use the default 'fbgemm' config (per-channel weight quantization)
    model.qconfig = torch.quantization.default_qconfig if mode == 'per_tensor' \
        else torch.quantization.get_default_qconfig('fbgemm')

    print(model.qconfig)
    torch.quantization.prepare(model, inplace=True)

    # Calibrate with the val set
    run_evaluation(net=model, device=device, dataset_path=dataset_path)
    print(f'Post Training {mode} Quantization: Calibration done')

    # Convert to quantized model
    torch.quantization.convert(model, inplace=True)
    print('Post Training Quantization: Convert done')
    # print('\n Inverted Residual Block: After fusion and quantization, note fused modules: \n\n',
    #       model.features[1].conv)

    print(f"Size of model after {mode} quantization")
    print_size_of_model(model)
    print("Performance of model after quantization")
    run_evaluation(net=model, device='cpu', dataset_path=dataset_path)

    s = run_speed_test(net=model, device=device)
    print(f"Speed of model after {mode} quantization: {s} ms/frame")

    torch.jit.save(torch.jit.script(model),
                   os.path.join(args.save_folder, f'{experiment_name()}/static_final{mode}.pth'))

Could you share the code for run_speed_test?

from time import time

def run_speed_test(weights=None, net=None, device='cpu', iters=1000):
    image = torch.rand(480, 640).numpy()
    predictor = Predictor(weights=weights, net=net, device=device)

    t0 = time()
    for i in range(iters):
        x = predictor(image)

    # note: the slow first passes are averaged in together with the fast ones
    delta_t = time() - t0
    return delta_t / float(iters) * 1000


def print_size_of_model(model):
    torch.save(model.state_dict(), 'temp.p')
    print('Size (MB):', os.path.getsize('temp.p') / 1e6)
    os.remove('temp.p')

Here the Predictor is a wrapper class that does all the preprocessing and post-processing. The slowness of the first two passes is not caused by the wrapper class: I tested the model separately, and the problem is still present.
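A loop along these lines (a minimal sketch: the Predictor wrapper is bypassed, and the NCHW input shape is an assumption) is enough to see the spike in the first two passes:

from time import time

import torch

def time_each_pass(model, iters=10):
    # a float input is fine here, since the model quantizes it via QuantStub
    x = torch.rand(1, 3, 480, 640)
    with torch.no_grad():
        for i in range(iters):
            t0 = time()
            model(x)
            print(f'pass {i}: {(time() - t0) * 1000:.1f} ms')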

I see. Does this happen for other models as well, or just yours? Is it possible to narrow down the part of the model that has this issue (by commenting out parts of the model)?

Yes, I did try other models, and the same behavior is present in all of them. Both the models I tried and the quantization scheme are from the official torchvision model zoo. As far as I can tell, this is a feature of torch 1.8 and above. I am wondering what the benefit of these slow initial passes is, and whether there is a way to switch that behavior off somehow?

We do not actually know why this happens. Can you use PyTorch Profiler — PyTorch Tutorials 1.9.0+cu102 documentation to see the breakdown and where the slowness comes from?
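Something along these lines should show which ops dominate the first passes (a minimal sketch using the torch.profiler API available since 1.8.1; the input shape is a placeholder for your real preprocessing):

import torch
from torch.profiler import profile, ProfilerActivity

x = torch.rand(1, 3, 480, 640)  # placeholder input
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with torch.no_grad():
        for _ in range(3):  # cover the two slow passes plus one fast one
            model(x)

print(prof.key_averages().table(sort_by='self_cpu_time_total', row_limit=20))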