Quantization of CNN with nn.functional.normalize

I am trying to quantize a custom FaceNet model. I put a QuantStub and a DeQuantStub at the beginning and end, but I am not sure I did it properly, because there is an nn.functional.normalize at the end which now receives quantized integer values instead of floats. After running the fuse_model function and then performing evaluation I get poor accuracy (50%, compared to the previous 82%). The model structure is:

import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub
from loss import TripletLoss

class NN1_BN_FaceNet(nn.Module):

    def __init__(self, classify=False, embedding_size = 128, device=None):

        super(NN1_BN_FaceNet, self).__init__()
        
        self.conv1 = nn.Sequential(nn.BatchNorm2d(3), nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
        
        self.conv2a = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, stride=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(nn.Conv2d(64, 192, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(192), nn.ReLU(inplace=True))
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True)

        self.conv3a = nn.Sequential(nn.Conv2d(192, 192, kernel_size=1, stride=1), nn.BatchNorm2d(192), nn.ReLU(inplace=True))
        self.conv3 = nn.Sequential(nn.Conv2d(192, 384, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(384), nn.ReLU(inplace=True))
        self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

        self.conv4a = nn.Sequential(nn.Conv2d(384, 384, kernel_size=1, stride=1), nn.BatchNorm2d(384), nn.ReLU(inplace=True))
        self.conv4 = nn.Sequential(nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True))
        
        self.conv5a = nn.Sequential(nn.Conv2d(256, 256, kernel_size=1, stride=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True))
        self.conv5 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True))

        self.conv6a = nn.Sequential(nn.Conv2d(256, 256, kernel_size=1, stride=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True))
        self.conv6 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True))

        self.pool4 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

        self.fc1 = nn.Sequential(nn.Linear(256*7*7, 4*128), nn.ReLU(inplace=True), nn.Dropout())
        self.fc7128 = nn.Sequential(nn.Linear(4*128, embedding_size))
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

        if device is None:
            self.device = torch.device('cpu')
        else:
            self.device = device
            self.to(device)

    def l2_norm(self, input):
        input_size = input.size()
        buffer = torch.pow(input, 2)
        normp = torch.sum(buffer, 1).add_(1e-10)
        norm = torch.sqrt(normp)
        _output = torch.div(input, norm.view(-1, 1).expand_as(input))
        output = _output.view(input_size)
        return output

    def freeze_all(self):
        for param in self.parameters():
            param.requires_grad = False

    def unfreeze_all(self):
        for param in self.parameters():
            param.requires_grad = True

    def freeze_fc(self):
        # the fully connected layers in this model are fc1 and fc7128
        for param in list(self.fc1.parameters()) + list(self.fc7128.parameters()):
            param.requires_grad = False

    def unfreeze_fc(self):
        for param in list(self.fc1.parameters()) + list(self.fc7128.parameters()):
            param.requires_grad = True

    def freeze_only(self, freeze):
        for name, child in self.named_children():
            if name in freeze:
                for param in child.parameters():
                    param.requires_grad = False
            else:
                for param in child.parameters():
                    param.requires_grad = True

    def unfreeze_only(self, unfreeze):
        for name, child in self.named_children():
            if name in unfreeze:
                for param in child.parameters():
                    param.requires_grad = True
            else:
                for param in child.parameters():
                    param.requires_grad = False

    def forward(self, x):
        
        x = self.quant(x)
        x = self.conv1(x)
        x = self.pool1(x)
                
        x = self.conv2a(x)
        x = self.conv2(x)
        x = self.pool2(x)

        x = self.conv3a(x)
        x = self.conv3(x)
        x = self.pool3(x)

        x = self.conv4a(x)
        x = self.conv4(x)

        x = self.conv5a(x)
        x = self.conv5(x)

        x = self.conv6a(x)
        x = self.conv6(x)

        x = self.pool4(x)

        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.fc7128(x)
        x = nn.functional.normalize(x, p=2, dim=1) 
        x = self.dequant(x)       
        return x
    
    def forward_classifier(self, x):
        features = self.forward(x)
        return features

    def fuse_model(self):
        for m in self.modules():
            if type(m) == nn.Sequential:
                if type(m[0])==nn.BatchNorm2d:
                    self.conv1[0] = nn.Identity()
                    torch.quantization.fuse_modules(self.conv1, ['1', '2', '3'], inplace=True)
                elif type(m[0])==nn.Conv2d and type(m[1])==nn.BatchNorm2d and type(m[2])==nn.ReLU:
                    torch.quantization.fuse_modules(m, ['0', '1', '2'], inplace=True)
                elif (type(m[0])==nn.Linear and len(m)>1):
                    torch.quantization.fuse_modules(m, ['0', '1'], inplace=True)
                else:       
                    print ('No fusion performed on this layer')
                    print(m)
        print('Fusion Complete')
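For context, this is roughly how I am quantizing the model after fusion, following the post-training static quantization tutorial (simplified sketch; calibration_loader stands in for my actual training DataLoader):

# rough sketch of my post-training static quantization flow
float_model.eval()
float_model.fuse_model()
float_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(float_model, inplace=True)

# calibrate the observers with a few batches of training data
with torch.no_grad():
    for images, _ in calibration_loader:   # placeholder loader
        float_model(images)

quant_model = torch.quantization.convert(float_model)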

Also, I am getting the following error when trying to save the model using torch.jit:

torch.jit.save(torch.jit.script(float_model), saved_model_dir + scripted_float_model_file)

RuntimeError:
normalize(Tensor input, float p=2., int dim=1, float eps=9.9999999999999998e-13, Tensor? out=None) -> (Tensor):
Expected a value of type 'float' for argument 'p' but instead found type 'int'.
:
File "", line 143
x = nn.functional.normalize(x, p=2, dim=1)
~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
x = self.dequant(x)
return x

Also, after converting to the quantized model and evaluating on the test set, I got the following error message, although surprisingly no error was shown while calibrating on the training set.

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-44-3be33b0a2ece> in <module>
      1 ### Accuracy of quantized model on validation Set
      2 print('Evaluating calibrated static quantized model on test set')
----> 3 eval_valid(float_model, triplet_loss, data_loaders_test, data_size_test)
      4 torch.save(model, saved_model_dir + 'S2_quant_model_per_tensor.pth')
      5 try:

<ipython-input-33-1a7e4dea2bf7> in eval_valid(model, triploss, dataloaders, data_size, phase)
     29 
     30                 # anc_embed, pos_embed and neg_embed are encoding(embedding) of image
---> 31                 anc_embed, pos_embed, neg_embed = model(anc_img), model(pos_img), model(neg_img)
     32 
     33                 # choose the semi hard negatives only for "training"

~/anaconda3/envs/py39/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-4-ff9b451a7e2d> in forward(self, x)
   --> 143         x = nn.functional.normalize(x, p=2, dim=1)
    144         x = self.dequant(x)
    145         return x

~/anaconda3/envs/py39/lib/python3.9/site-packages/torch/nn/functional.py in normalize(input, p, dim, eps, out)
   4445         return handle_torch_function(normalize, (input, out), input, p=p, dim=dim, eps=eps, out=out)
   4446     if out is None:
-> 4447         denom = input.norm(p, dim, keepdim=True).clamp_min(eps).expand_as(input)
   4448         return input / denom
   4449     else:

~/anaconda3/envs/py39/lib/python3.9/site-packages/torch/_tensor.py in norm(self, p, dim, keepdim, dtype)
    440         if has_torch_function_unary(self):
    441             return handle_torch_function(Tensor.norm, (self,), self, p=p, dim=dim, keepdim=keepdim, dtype=dtype)
--> 442         return torch.norm(self, p, dim, keepdim, dtype=dtype)
    443 
    444     def lu(self, pivot=True, get_infos=False):

~/anaconda3/envs/py39/lib/python3.9/site-packages/torch/functional.py in norm(input, p, dim, keepdim, out, dtype)
   1463         if out is None:
   1464             if dtype is None:
-> 1465                 return _VF.norm(input, p, _dim, keepdim=keepdim)  # type: ignore[attr-defined]
   1466             else:
   1467                 return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype)  # type: ignore[attr-defined]

RuntimeError: norm(): input dtype should be either floating point or complex dtypes. Got QUInt8 instead.

Please suggest what I can do to resolve this. Thanks!

Hi Avishek,

To fix the error, you can try putting the dequant before x = nn.functional.normalize(...) so that it receives floating-point values. Otherwise the way you're setting up the model looks reasonable to me, including the fusion code. I'm not sure why you're getting poor accuracy. Are you running dynamic quantization or post-training static quantization? If so, you could try quantization-aware training, which often improves accuracy.
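Roughly, the end of your forward would then look like this (just a sketch of the reordering):

# end of forward(), with dequant moved before the functional normalize
x = self.fc7128(x)
x = self.dequant(x)   # back to float before torch.norm is called internally
x = nn.functional.normalize(x, p=2, dim=1)
return x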

Best,
-Andrew

Hi Andrew,

Thanks for your response. Yes, I made the change you suggested; however, I am still getting low accuracy of around ~53% with the quantized model, as opposed to around 82% with the floating-point model, using post-training static quantization as suggested in the PyTorch tutorial on MobileNet.

I have attached my code here with a small dataset; you may need to change the file paths. If you can have a look and make any recommendations, I would appreciate it. I am planning to compare the output vectors and weights of the quantized and floating-point models to try to figure out what the issue is.
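Something like this is what I have in mind for the comparison (just a sketch; float_model, quant_model and sample_batch are placeholders for my actual objects):

# compare the embeddings produced by the float and the quantized model on one batch
with torch.no_grad():
    emb_float = float_model(sample_batch)
    emb_quant = quant_model(sample_batch)

# per-sample cosine similarity and largest element-wise deviation
cos_sim = nn.functional.cosine_similarity(emb_float, emb_quant, dim=1)
print('mean cosine similarity:', cos_sim.mean().item())
print('max abs difference:', (emb_float - emb_quant).abs().max().item())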

I have also decided to move on to quantization-aware training (QAT) and am currently working on implementing that.
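The flow I am planning to try for QAT looks roughly like this (sketch based on the same tutorial; train_one_epoch, train_loader and num_epochs are placeholders for my own training code):

# rough sketch of the quantization-aware training flow
qat_model.train()                    # fusion and prepare_qat expect training mode
qat_model.fuse_model()
qat_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(qat_model, inplace=True)

# fine-tune for a few epochs with fake quantization inserted
for epoch in range(num_epochs):
    train_one_epoch(qat_model, train_loader)

qat_model.eval()
quant_model = torch.quantization.convert(qat_model)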

Thanks,
Avishek

I managed to resolve the low-accuracy issue. It was happening because, in the fuse_model function, I was replacing the very first BatchNorm layer with an identity layer and still expecting it to work properly :sweat_smile:
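In other words, leaving the first BatchNorm2d in place and only fusing the Conv-BN-ReLU part of conv1 fixed it. Roughly, that branch of fuse_model now looks like this (sketch, assuming the leading BatchNorm2d is simply left unfused):

if type(m[0]) == nn.BatchNorm2d:
    # keep the leading BatchNorm2d as-is (do not replace it with nn.Identity)
    # and only fuse the Conv2d-BatchNorm2d-ReLU that follows it
    torch.quantization.fuse_modules(self.conv1, ['1', '2', '3'], inplace=True)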

However, I am still getting an error when trying to save this model as a TorchScript module, using the command

torch.jit.save(torch.jit.script(float_model), saved_model_dir + scripted_float_model_file)

The error message is:

RuntimeError:
normalize(Tensor input, float p=2., int dim=1, float eps=9.9999999999999998e-13, Tensor? out=None) -> (Tensor):
Expected a value of type 'float' for argument 'p' but instead found type 'int'.
:
File "", line 143
x = nn.functional.normalize(x, p=2, dim=1)
~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
x = self.dequant(x)
return x

You can probably try changing p=2 to p=2.0.
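That is, the call would become:

x = nn.functional.normalize(x, p=2.0, dim=1)   # p as a float matches TorchScript's normalize signature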

Wow! You are on a roll @jerryzh168. Setting it to p=2.0 solved the issue!
