Unable to convert a quantized model to TorchScript

Hi,
I am following the official tutorials here and here to quantize a model but it is errors out while saving to TorchScript.

Below is the code to reproduce this error:
Step 1 - imports

import timm
import torch
import torch.nn as nn
import torchaudio.transforms as AT
import torchvision.transforms as VT
from nnAudio import features

Step 2 : Define methods as per the tutorial

def quant_prep(model_ft):
    model_ft[0].qconfig = torch.quantization.default_qat_qconfig  # Use default QAT configuration
# Step 3
    model_ft = torch.quantization.prepare_qat(model_ft, inplace=True)
    for param in model_ft.parameters():
        param.requires_grad = True
    return(model_ft)
import torchvision.models.quantization as models

model = models.resnet18(pretrained=True, progress=True, quantize=False)
num_ftrs = model.fc.in_features
model.train()
model.fuse_model()
from torch import nn

def create_combined_model(model_fe):
    # Step 1. Isolate the feature extractor.
    model_fe_features = nn.Sequential(
    model_fe.quant,  # Quantize the input
    model_fe.conv1,
    model_fe.bn1,
    model_fe.relu,
    model_fe.maxpool,
    model_fe.layer1,
    model_fe.layer2,
    model_fe.layer3,
    model_fe.layer4,
    model_fe.avgpool,
    model_fe.dequant,  # Dequantize the output
  )

    # Step 2. Create a new "head"
    new_head = nn.Sequential(
    nn.Dropout(p=0.5),
    nn.Linear(num_ftrs, 2),
   )

  # Step 3. Combine, and don't forget the quant stubs.
    new_model = nn.Sequential(
    model_fe_features,
    nn.Flatten(1),
    new_head,)
    
    model_ft  = quant_prep(new_model)
    
    return model_ft

Step 3 - define custom “model”

class MyModel(nn.Module):
    def __init__(self,model = model,image_size = 224):
        super().__init__()
        # num_classes=0 removes the pretrained head
        self.backbone = create_combined_model(model)
        self.sizer = VT.Resize((image_size,image_size),antialias = True)
        self.spec_layer = AT.Spectrogram(n_fft = int(config.NFFT), return_complex= False,).to('cuda')
        self.batch_norm = nn.BatchNorm2d(num_features= 1)
               
        
    def forward(self, x,train = True):
        # first compute spectrogram
        spec_gram = self.spec_layer(x).to('cuda')
        spec_gram = self.batch_norm(spec_gram)
        spec_gram_nan_check = torch.isnan(spec_gram).any().item()
        assert not (spec_gram_nan_check) ,"Tensor contains NaN values after spec gram creation."
        
        with torch.no_grad():
            if train == True:
                #generate a random number and if condition is met apply aug
                ta_transformations_rndm_choice = VT.RandomChoice([AT.FrequencyMasking(freq_mask_param=100),AT.TimeMasking(time_mask_param=50)], p=[.4, .4])
                ta_transformations_rndm_apply = VT.RandomApply([AT.FrequencyMasking(freq_mask_param=50),AT.TimeMasking(time_mask_param=25)],p = .15)
                spec_gram = ta_transformations_rndm_choice(spec_gram)
                spec_gram = ta_transformations_rndm_apply(spec_gram)
                spec_gram_nan_check = torch.isnan(spec_gram).any().item()
                assert not (spec_gram_nan_check) ,"Tensor contains NaN values after augmentations  "
                
                    
        
        x = self.sizer(spec_gram.squeeze(dim = 1))
        x = x.unsqueeze(dim = 1)
              
        # then repeat channels
        del spec_gram,spec_gram_nan_check
        if DEBUG:
            print("Final shape that goes to backbone = " + str(x.shape))
                
        x = x.expand(-1, 3, -1, -1)
        x = self.backbone(x)
        backbone_op_nan_check = torch.isnan(x).any().item()
        assert not (backbone_op_nan_check) ,"Tensor contains NaN values in the backbone OP "
        pred = x
        output = {"prediction": pred }
        #print(output)
        del x , backbone_op_nan_check
        return output

Step 4 - create an instance of MyModel

model_ft_tuned_new = MyModel(model)

Step 5:

from torch.quantization import convert
model_ft_tuned_new.cpu()
model_quantized_and_trained = convert(model_ft_tuned_new, inplace=False)

Step 6:- converting to torchscript( this is where the model errors out):

torch.jit.script(model_quantized_and_trained ,"resnet_ft.pt")

The stack trace of the error is as below :

/opt/conda/lib/python3.8/site-packages/torch/jit/_script.py:1277: UserWarning: `optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead
  warnings.warn(
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_1270/2629330692.py in <module>
----> 1 torch.jit.script(model_quantized_and_trained ,"resnet_ft.pt")

/opt/conda/lib/python3.8/site-packages/torch/jit/_script.py in script(obj, optimize, _frames_up, _rcb, example_inputs)
   1322     if isinstance(obj, torch.nn.Module):
   1323         obj = call_prepare_scriptable_func(obj)
-> 1324         return torch.jit._recursive.create_script_module(
   1325             obj, torch.jit._recursive.infer_methods_to_compile
   1326         )

/opt/conda/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
    557     if not is_tracing:
    558         AttributeTypeIsSupportedChecker().check(nn_module)
--> 559     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
    560 
    561 

/opt/conda/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    634     # Compile methods if necessary
    635     if concrete_type not in concrete_type_store.methods_compiled:
--> 636         create_methods_and_properties_from_stubs(
    637             concrete_type, method_stubs, property_stubs
    638         )

/opt/conda/lib/python3.8/site-packages/torch/jit/_recursive.py in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    467     property_rcbs = [p.resolution_callback for p in property_stubs]
    468 
--> 469     concrete_type._create_methods_and_properties(
    470         property_defs, property_rcbs, method_defs, method_rcbs, method_defaults
    471     )

RuntimeError: 
Arguments for call are not valid.
The following variants are available:
  
  aten::eq.Tensor(Tensor self, Tensor other) -> Tensor:
  Expected a value of type 'Tensor' for argument 'other' but instead found type 'bool'.
  
  aten::eq.Scalar(Tensor self, Scalar other) -> Tensor:
  Expected a value of type 'number' for argument 'other' but instead found type 'bool'.
  
  aten::eq.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!):
  Expected a value of type 'number' for argument 'other' but instead found type 'bool'.
  
  aten::eq.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!):
  Expected a value of type 'Tensor' for argument 'other' but instead found type 'bool'.
  
  aten::eq.int_list(int[] a, int[] b) -> bool:
  Expected a value of type 'List[int]' for argument 'a' but instead found type 'Tensor (inferred)'.
  Inferred the value for argument 'a' to be of type 'Tensor' because it was not annotated with an explicit type.
  
  aten::eq.device(Device a, Device b) -> bool:
  Expected a value of type 'Device' for argument 'a' but instead found type 'Tensor (inferred)'.
  Inferred the value for argument 'a' to be of type 'Tensor' because it was not annotated with an explicit type.
  
  aten::eq.bool(bool a, bool b) -> bool:
  Expected a value of type 'bool' for argument 'a' but instead found type 'Tensor (inferred)'.
  Inferred the value for argument 'a' to be of type 'Tensor' because it was not annotated with an explicit type.
  
  aten::eq.enum(AnyEnumType a, AnyEnumType b) -> bool:
  Expected a value of type 'AnyEnumType' for argument 'a' but instead found type 'Tensor (inferred)'.
  Inferred the value for argument 'a' to be of type 'Tensor' because it was not annotated with an explicit type.
  
  aten::eq.int(int a, int b) -> bool:
  Expected a value of type 'int' for argument 'b' but instead found type 'bool'.
  
  aten::eq.complex(complex a, complex b) -> bool:
  Expected a value of type 'complex' for argument 'b' but instead found type 'bool'.
  
  aten::eq.float(float a, float b) -> bool:
  Expected a value of type 'float' for argument 'b' but instead found type 'bool'.
  
  aten::eq.int_float(int a, float b) -> bool:
  Expected a value of type 'float' for argument 'b' but instead found type 'bool'.
  
  aten::eq.float_int(float a, int b) -> bool:
  Expected a value of type 'int' for argument 'b' but instead found type 'bool'.
  
  aten::eq.float_complex(float a, complex b) -> bool:
  Expected a value of type 'complex' for argument 'b' but instead found type 'bool'.
  
  aten::eq.complex_float(complex a, float b) -> bool:
  Expected a value of type 'float' for argument 'b' but instead found type 'bool'.
  
  aten::eq(Scalar a, Scalar b) -> bool:
  Expected a value of type 'number' for argument 'b' but instead found type 'bool'.
  
  aten::eq.str(str a, str b) -> bool:
  Expected a value of type 'str' for argument 'a' but instead found type 'Tensor (inferred)'.
  Inferred the value for argument 'a' to be of type 'Tensor' because it was not annotated with an explicit type.
  
  aten::eq.float_list(float[] a, float[] b) -> bool:
  Expected a value of type 'List[float]' for argument 'a' but instead found type 'Tensor (inferred)'.
  Inferred the value for argument 'a' to be of type 'Tensor' because it was not annotated with an explicit type.
  
  aten::eq.Tensor_list(Tensor[] a, Tensor[] b) -> bool:
  Expected a value of type 'List[Tensor]' for argument 'a' but instead found type 'Tensor (inferred)'.
  Inferred the value for argument 'a' to be of type 'Tensor' because it was not annotated with an explicit type.
  
  aten::eq.bool_list(bool[] a, bool[] b) -> bool:
  Expected a value of type 'List[bool]' for argument 'a' but instead found type 'Tensor (inferred)'.
  Inferred the value for argument 'a' to be of type 'Tensor' because it was not annotated with an explicit type.
  
  aten::eq.str_list(str[] a, str[] b) -> bool:
  Expected a value of type 'List[str]' for argument 'a' but instead found type 'Tensor (inferred)'.
  Inferred the value for argument 'a' to be of type 'Tensor' because it was not annotated with an explicit type.
  
  eq(float a, Tensor b) -> Tensor:
  Expected a value of type 'Tensor' for argument 'b' but instead found type 'bool'.
  
  eq(int a, Tensor b) -> Tensor:
  Expected a value of type 'Tensor' for argument 'b' but instead found type 'bool'.
  
  eq(complex a, Tensor b) -> Tensor:
  Expected a value of type 'Tensor' for argument 'b' but instead found type 'bool'.

The original call is:
  File "/tmp/ipykernel_1270/2151903598.py", line 33
        
        with torch.no_grad():
            if train == True:
               ~~~~~~~~~~~~~ <--- HERE
                #generate a random number and if condition is met apply aug
                ta_transformations_rndm_choice = VT.RandomChoice([AT.FrequencyMasking(freq_mask_param=100),AT.TimeMasking(time_mask_param=50)], p=[.4, .4])

it’s probably because this is just not scriptable, could you remove the if train == True code?

Thanks Jerry. It started working after removing the if train == True part.
Appreciate your help.

1 Like