Hi all, I’m trying to quantize a citrinet ASR model and convert it to a torchscript model. I’m getting the following error :
RuntimeError Traceback (most recent call last)
/tmp/ipykernel_57531/577484169.py in <module>
6 loader = AudioFileLoader(sample_rate=16000)
7 audio = loader("/DataRW/ML1/amogha/transformer_lang_model/audio_files/1.wav")
----> 8 print(model_static_quantized.predict(audio))
/tmp/ipykernel_57531/390047189.py in predict(self, x)
100 #x=self.quant(x)
101 audio_lengths = torch.tensor(x.shape[0] * [x.shape[-1]], device=x.device)
--> 102 pred, _ = self(x, audio_lengths)
103 #pred = self.dequant(pred)
104 return self.text_transform.decode_prediction(pred.argmax(1))
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
/tmp/ipykernel_57531/390047189.py in forward(self, x, lengths)
86 lengths=lengths.float()
87 lengths=self.quant(lengths)
---> 88 features, feature_lengths = self.audio_transform(x, lengths)
89 encoded, out_lengths = self.encoder(features, feature_lengths)
90 return self.dequant(self.decoder(encoded)), self.dequant(out_lengths)
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
/opt/conda/lib/python3.8/site-packages/thunder/blocks.py in forward(self, audio, audio_lengths)
99 ) -> Tuple[torch.Tensor, torch.Tensor]:
100 for module in self.children():
--> 101 audio, audio_lengths = module(audio, audio_lengths)
102 return audio, audio_lengths
103
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
/opt/conda/lib/python3.8/site-packages/thunder/blocks.py in forward(self, audio, audio_lengths)
113 self, audio: torch.Tensor, audio_lengths: torch.Tensor
114 ) -> Tuple[torch.Tensor, torch.Tensor]:
--> 115 return self.layer(audio), audio_lengths
116
117
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input)
139 def forward(self, input):
140 for module in self:
--> 141 input = module(input)
142 return input
143
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
/opt/conda/lib/python3.8/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
25 def decorate_context(*args, **kwargs):
26 with self.clone():
---> 27 return func(*args, **kwargs)
28 return cast(F, decorate_context)
29
/opt/conda/lib/python3.8/site-packages/thunder/quartznet/transform.py in forward(self, x)
140 """
141 return torch.cat(
--> 142 (x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1
143 )
144
RuntimeError: promoteTypes with quantized numbers is not handled yet; figure out what the correct rules should be, offending types: QUInt8 Float
I added self.quant=torch.quantization.QuantStub() and self.dequant=torch.quantization.DeQuantStub() in the init of the model.
the forward function looks like this :
def forward(self, x: Tensor, lengths: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
"""Process the audio tensor to create the predictions.
Args:
x: Audio tensor of shape [batch_size, time]
lengths: corresponding length of each element in the input tensor.
Returns:
Tensor with the predictions.
"""
x=self.quant(x)
lengths=lengths.float()
lengths=self.quant(lengths)
features, feature_lengths = self.audio_transform(x, lengths)
encoded, out_lengths = self.encoder(features, feature_lengths)
return self.dequant(self.decoder(encoded)), self.dequant(out_lengths)
How to fix this ? thanks.