Inference on Android: Could not run 'quantized::conv1d'

I just compiled the real-time speech enhancement model from
https://github.com/facebookresearch/denoiser
with TorchScript and wanted to test whether it can run in real time on Android devices.

Following the official tutorial, I used this code for compilation:

from denoiser.pretrained import dns48
from denoiser.demucs import DemucsStreamer
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

model = dns48()
model.eval()
streamer = DemucsStreamer(model)
streamer.eval()

# Post-training static quantization with the QNNPACK backend (ARM CPUs).
streamer.qconfig = torch.quantization.get_default_qconfig("qnnpack")
torch.quantization.prepare(streamer, inplace=True)
streamer = torch.quantization.convert(streamer, inplace=True)

# Script, optimize, and export for the mobile (lite) interpreter.
torchscript_model = torch.jit.script(streamer)
optimized_model = optimize_for_mobile(torchscript_model)

optimized_model._save_for_lite_interpreter("denoiser_dns48_quantized.ptl")

And of course I had to slightly modify the denoiser code (adding some type hints, etc.) to make it compile. Without quantization it runs, but it is way too slow. After adding the quantization step, I got this error:

com.facebook.jni.CppException: Could not run 'quantized::conv1d' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'quantized::conv1d' is only available for these backends: [QuantizedCPU, BackendSelect, Functionalize, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradMPS, AutogradXPU, AutogradHPU, AutogradLazy].

I am not sure if my backend falls into QuantizedCPU, or if it is really not supported. Is there anything I can do?

What this is saying is that the quantized conv1d is somehow getting an FP32 tensor as input instead of a quantized one. You might want to take a look at the quantized TorchScript graph (do m = torch.jit.load(...), print(m.graph)) and see what the input to conv1d is.
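For example, something along these lines will dump the graph (assuming you also save the scripted model with torch.jit.save; the filename is just illustrative):

import torch

# Load the scripted model and print its graph to check whether the tensor
# feeding conv1d is quantized or still fp32.
m = torch.jit.load("denoiser_dns48_quantized.pt")
print(m.graph)
# m.code gives a more readable, Python-like view of the same thing.
print(m.code)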

Here is an excerpt from the graph including all the inputs to the first conv1d operation

%self.demucs.chin : int = prim::Constant[value=1]()
%35 : int[] = prim::Constant[value=[56]]()
%41 : int[] = prim::Constant[value=[1]]()
%52 : NoneType = prim::Constant()
%kernel.5 : Tensor = aten::to(%kernel.3, %padded_frame.1, %self.demucs.lstm.lstm.training, %self.demucs.lstm.lstm.training, %52) # denoiser/resample.py:45:13
%150 : Tensor = aten::view(%padded_frame.1, %149) # denoiser/resample.py:46:19
%151 : Tensor = aten::conv1d(%150, %kernel.5, %52, %41, %35, %41, %self.demucs.chin) # denoiser/resample.py:46:10

To be honest, I have no clue what’s going on here; the model is pretty complex.

I have also found this list of supported operations https://github.com/pytorch/pytorch, which does suggest that operator support on Android is actually extremely limited.

Can you get an equivalent graph snapshot that includes quantized::conv1d? The one above has aten::conv1d, which is the FP32 conv.

I’m sorry I sent you the wrong graph. Here it is:

%x.651 : Tensor, %prev.223 : Tensor = prim::If(%179) # denoiser/demucs.py:376:16
  block0():
    %conv_state.3 : Tensor[] = prim::GetAttr[name="conv_state"](%self)
    %prev.3 : Tensor = aten::pop(%conv_state.3, %self.demucs.decoder.0.0.zero_point) # denoiser/demucs.py:377:27
    %prev.9 : Tensor = aten::slice(%prev.3, %37, %self.resample_buffer, %36, %self.demucs.chin) # denoiser/demucs.py:378:27
    %185 : int = aten::sub(%length.1, %8) # denoiser/demucs.py:379:27
    %186 : int = aten::floordiv(%185, %resample.1) # denoiser/demucs.py:379:27
    %tgt.1 : int = aten::add(%186, %self.demucs.chin) # denoiser/demucs.py:379:27
    %188 : int[] = aten::size(%prev.9) # <string>:13:9
    %189 : int = aten::__getitem__(%188, %37) # denoiser/demucs.py:380:36
    %missing.1 : int = aten::sub(%tgt.1, %189) # denoiser/demucs.py:380:30
    %191 : int = aten::sub(%missing.1, %self.demucs.chin) # denoiser/demucs.py:381:76
    %192 : int = aten::mul(%191, %resample.1) # denoiser/demucs.py:381:59
    %offset.1 : int = aten::sub(%185, %192) # denoiser/demucs.py:381:29
    %x.406 : Tensor = aten::slice(%x.1, %37, %offset.1, %36, %self.demucs.chin) # denoiser/demucs.py:382:24
    -> (%x.406, %prev.9)
  block1():
    -> (%x.1, %10)
%self.demucs.encoder.0.0._packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant[value=object(0x55ccf8db1390)]()
%self.demucs.decoder.0.0.scale : float = prim::Constant[value=1.]()
%self.demucs.decoder.0.0.zero_point : int = prim::Constant[value=0]()
%198 : Tensor = quantized::conv1d(%x.651, %self.demucs.encoder.0.0._packed_params, %self.demucs.decoder.0.0.scale, %self.demucs.decoder.0.0.zero_point) # venv/lib/python3.8/site-packages/torch/ao/nn/quantized/modules/conv.py:369:15

I also want to note that I did not add the quantization stubs to the model:

torch.quantization.QuantStub()
torch.quantization.DeQuantStub()

I still can’t get the whole model to work, but I feel like that may have caused the error.

How did you generate the quantized model? @jerryzh168

This is calling the eager mode quantization flow, so QuantStub/DeQuantStub are needed to make sure the inputs to quantized::conv1d are quantized.
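A minimal sketch of that eager-mode pattern, assuming your modified DemucsStreamer behaves like an nn.Module whose forward takes a single tensor (the wrapper class and names below are illustrative, not part of the denoiser code):

import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

class QuantizedStreamer(torch.nn.Module):
    # Wraps the streamer so the input is quantized before the first
    # quantized::conv1d and dequantized again at the output.
    def __init__(self, streamer):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.streamer = streamer
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)        # fp32 -> quantized
        x = self.streamer(x)
        return self.dequant(x)   # quantized -> fp32

wrapped = QuantizedStreamer(streamer)
wrapped.eval()
wrapped.qconfig = torch.quantization.get_default_qconfig("qnnpack")
torch.quantization.prepare(wrapped, inplace=True)
# Feed a few representative audio frames through `wrapped` here so the
# observers can calibrate before conversion.
torch.quantization.convert(wrapped, inplace=True)

optimized = optimize_for_mobile(torch.jit.script(wrapped))
optimized._save_for_lite_interpreter("denoiser_dns48_quantized.ptl")

torch.quantization.QuantWrapper(streamer) does the same wrapping in one line if you don’t need the explicit class.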