RuntimeError: promoteTypes with quantized numbers is not handled yet

Hi all, I’m trying to quantize a citrinet ASR model and convert it to a torchscript model. I’m getting the following error :

RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_57531/577484169.py in <module>
      6 loader = AudioFileLoader(sample_rate=16000)
      7 audio = loader("/DataRW/ML1/amogha/transformer_lang_model/audio_files/1.wav")
----> 8 print(model_static_quantized.predict(audio))

/tmp/ipykernel_57531/390047189.py in predict(self, x)
    100         #x=self.quant(x)
    101         audio_lengths = torch.tensor(x.shape[0] * [x.shape[-1]], device=x.device)
--> 102         pred, _ = self(x, audio_lengths)
    103         #pred = self.dequant(pred)
    104         return self.text_transform.decode_prediction(pred.argmax(1))

/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/tmp/ipykernel_57531/390047189.py in forward(self, x, lengths)
     86         lengths=lengths.float()
     87         lengths=self.quant(lengths)
---> 88         features, feature_lengths = self.audio_transform(x, lengths)
     89         encoded, out_lengths = self.encoder(features, feature_lengths)
     90         return self.dequant(self.decoder(encoded)), self.dequant(out_lengths)

/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.8/site-packages/thunder/blocks.py in forward(self, audio, audio_lengths)
     99     ) -> Tuple[torch.Tensor, torch.Tensor]:
    100         for module in self.children():
--> 101             audio, audio_lengths = module(audio, audio_lengths)
    102         return audio, audio_lengths
    103 

/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.8/site-packages/thunder/blocks.py in forward(self, audio, audio_lengths)
    113         self, audio: torch.Tensor, audio_lengths: torch.Tensor
    114     ) -> Tuple[torch.Tensor, torch.Tensor]:
--> 115         return self.layer(audio), audio_lengths
    116 
    117 

/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input)
    139     def forward(self, input):
    140         for module in self:
--> 141             input = module(input)
    142         return input
    143 

/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.8/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     25         def decorate_context(*args, **kwargs):
     26             with self.clone():
---> 27                 return func(*args, **kwargs)
     28         return cast(F, decorate_context)
     29 

/opt/conda/lib/python3.8/site-packages/thunder/quartznet/transform.py in forward(self, x)
    140         """
    141         return torch.cat(
--> 142             (x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1
    143         )
    144 

RuntimeError: promoteTypes with quantized numbers is not handled yet; figure out what the correct rules should be, offending types: QUInt8 Float

I added self.quant=torch.quantization.QuantStub() and self.dequant=torch.quantization.DeQuantStub() in the init of the model.
the forward function looks like this :

def forward(self, x: Tensor, lengths: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        """Process the audio tensor to create the predictions.
        Args:
            x: Audio tensor of shape [batch_size, time]
            lengths: corresponding length of each element in the input tensor.
        Returns:
            Tensor with the predictions.
        """
        x=self.quant(x)
        lengths=lengths.float()
        lengths=self.quant(lengths)
        features, feature_lengths = self.audio_transform(x, lengths)
        encoded, out_lengths = self.encoder(features, feature_lengths)
        return self.dequant(self.decoder(encoded)), self.dequant(out_lengths)

How to fix this ? thanks.

this means some operations are performed between a quantized tensor and a floating point tensor, which is not supported. possibly self.preemph * x[:, :-1] here I think. Assuming self.preemph is a scalar value, the solution would be rewrite this with a floatfunctional.mul_scalar like
self.float_functional.mul_scalar(x[:, :-1], self.preemph) and attach a qconfig for self.float_functional