Error while trying to save a quantized model using jit

I followed the steps mentioned in the pytorch quantization tutorial and tried to save the model using jit but getting an error.

Here is the model structure

from torch.quantization import QuantStub, DeQuantStub
from loss import TripletLoss

class NN1_BN_FaceNet(nn.Module):

    def __init__(self, classify=False, embedding_size = 128, device=None):

        super(NN1_BN_FaceNet, self).__init__()
        
        self.conv1 = nn.Sequential(nn.BatchNorm2d(3), nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
        
        self.conv2a = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, stride=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(nn.Conv2d(64, 192, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(192), nn.ReLU(inplace=True))
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True)

        self.conv3a = nn.Sequential(nn.Conv2d(192, 192, kernel_size=1, stride=1), nn.BatchNorm2d(192), nn.ReLU(inplace=True))
        self.conv3 = nn.Sequential(nn.Conv2d(192, 384, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(384), nn.ReLU(inplace=True))
        self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

        self.conv4a = nn.Sequential(nn.Conv2d(384, 384, kernel_size=1, stride=1), nn.BatchNorm2d(384), nn.ReLU(inplace=True))
        self.conv4 = nn.Sequential(nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True))
        
        self.conv5a = nn.Sequential(nn.Conv2d(256, 256, kernel_size=1, stride=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True))
        self.conv5 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True))

        self.conv6a = nn.Sequential(nn.Conv2d(256, 256, kernel_size=1, stride=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True))
        self.conv6 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True))

        self.pool4 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

        self.fc1 = nn.Sequential(nn.Linear(256*7*7, 4*128), nn.ReLU(inplace=True), nn.Dropout())
        self.fc7128 = nn.Sequential(nn.Linear(4*128, embedding_size))
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

        if device is None:
            self.device = torch.device('cpu')
        else:
            self.device = device
            self.to(device)

        def l2_norm(self, input):
            input_size = input.size()
            buffer = torch.pow(input, 2)
            normp = torch.sum(buffer, 1).add_(1e-10)
            norm = torch.sqrt(normp)
            _output = torch.div(input, norm.view(-1, 1).expand_as(input))
            output = _output.view(input_size)
            return output

    def freeze_all(self):
        for param in self.parameters():
            param.requires_grad = False

    def unfreeze_all(self):
        for param in self.parameters():
            param.requires_grad = True

    def freeze_fc(self):
        for param in self.fc.parameters():
            param.requires_grad = False

    def unfreeze_fc(self):
        for param in self.fc.parameters():
            param.requires_grad = True

    def freeze_only(self, freeze):
        for name, child in self.named_children():
            if name in freeze:
                for param in child.parameters():
                    param.requires_grad = False
            else:
                for param in child.parameters():
                    param.requires_grad = True

    def unfreeze_only(self, unfreeze):
        for name, child in self.named_children():
            if name in unfreeze:
                for param in child.parameters():
                    param.requires_grad = True
            else:
                for param in child.parameters():
                    param.requires_grad = False

    def forward(self, x):
        
        x = self.quant(x)
        x = self.conv1(x)
        x = self.pool1(x)
                
        x = self.conv2a(x)
        x = self.conv2(x)
        x = self.pool2(x)

        x = self.conv3a(x)
        x = self.conv3(x)
        x = self.pool3(x)

        x = self.conv4a(x)
        x = self.conv4(x)

        x = self.conv5a(x)
        x = self.conv5(x)

        x = self.conv6a(x)
        x = self.conv6(x)

        x = self.pool4(x)

        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.fc7128(x)
        x = nn.functional.normalize(x, p=2, dim=1) 
        x = self.dequant(x)       
        return x
    
    def forward_classifier(self, x):
        features = self.forward(x)
        return features

    def fuse_model(self):
        for m in self.modules():
            if type(m) == nn.Sequential:
                if type(m[0])==nn.BatchNorm2d:
                    self.conv1[0] = nn.Identity()
                    torch.quantization.fuse_modules(self.conv1, ['1', '2', '3'], inplace=True)
                elif type(m[0])==nn.Conv2d and type(m[1])==nn.BatchNorm2d and type(m[2])==nn.ReLU:
                    torch.quantization.fuse_modules(m, ['0', '1', '2'], inplace=True)
                elif (type(m[0])==nn.Linear and len(m)>1):
                    torch.quantization.fuse_modules(m, ['0', '1'], inplace=True)
                else:       
                    print ('No fusion performed on this layer')
                    print(m)
        print('Fusion Complete')

I am getting the following error when trying to save the model using torch.jit

torch.jit.save(torch.jit.script(float_model), saved_model_dir + scripted_float_model_file)

The error message

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-19-d0b6b1651050> in <module>
      6                 columns=[ 'Train_Accuracy', 'Train_Threshold', 'Test_Accuracy', 'Test_Threshold'])
      7 Accuracy_Threshold_Table.to_csv(saved_model_dir+'Acc_Thres_Train_'+str(np.round(np.mean(accuracy_train), 3))+'_Test_'+str(np.round(np.mean(accuracy_test), 3))+'.csv')
----> 8 torch.jit.save(torch.jit.script(float_model), saved_model_dir + scripted_quantized_model_file_S2)
      9 # try:
     10 #     torch.jit.save(torch.jit.script(float_model), saved_model_dir + scripted_quantized_model_file_S2)

~/anaconda3/envs/py39/lib/python3.9/site-packages/torch/jit/_script.py in script(obj, optimize, _frames_up, _rcb, example_inputs)
   1255     if isinstance(obj, torch.nn.Module):
   1256         obj = call_prepare_scriptable_func(obj)
-> 1257         return torch.jit._recursive.create_script_module(
   1258             obj, torch.jit._recursive.infer_methods_to_compile
   1259         )

~/anaconda3/envs/py39/lib/python3.9/site-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
    449     if not is_tracing:
    450         AttributeTypeIsSupportedChecker().check(nn_module)
--> 451     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
    452 
    453 def create_script_module_impl(nn_module, concrete_type, stubs_fn):

~/anaconda3/envs/py39/lib/python3.9/site-packages/torch/jit/_recursive.py in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    515     # Compile methods if necessary
    516     if concrete_type not in concrete_type_store.methods_compiled:
--> 517         create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    518         # Create hooks after methods to ensure no name collisions between hooks and methods.
    519         # If done before, hooks can overshadow methods that aren't exported.
~/anaconda3/envs/py39/lib/python3.9/site-packages/torch/jit/_recursive.py in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    366     property_rcbs = [p.resolution_callback for p in property_stubs]
    367 
--> 368     concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
    369 
    370 def create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs):

RuntimeError: 

normalize(Tensor input, float p=2., int dim=1, float eps=9.9999999999999998e-13, Tensor? out=None) -> (Tensor):
Expected a value of type 'float' for argument 'p' but instead found type 'int'.
:
  File "<ipython-input-4-129ce50ca4a2>", line 145
        
        x = self.dequant(x)
        x = nn.functional.normalize(x, p=2, dim=1)                
            ~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        return x

Any suggestions would be appreciated