Is there any way to run nn.ConvTranspose2d with Metal?

Hi! I'm trying to run my model on an iOS device (an iPhone X) with GPU support via Metal, but I'm hitting some errors, so I hope you can help. Here is a small example:

import torch
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile

class TEST(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.convt = torch.nn.ConvTranspose2d(3, 3, 3)

    def forward(self, x):
        x = self.quant(x)
        x = self.convt(x)
        x = self.dequant(x)
        return x

rand_input = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    model = TEST().eval()
    backend = "qnnpack" #"fbgemm" # this var also error
    torch.backends.quantized.engine = backend
    model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
    torch.quantization.prepare(model, inplace=True)
    model(rand_input)  # calibration pass so the observers collect statistics
    torch.quantization.convert(model, inplace=True)
    script_module = torch.jit.script(model)
    optimized_model = optimize_for_mobile(script_module, backend='metal')
    print(torch.jit.export_opnames(optimized_model))
    optimized_model._save_for_lite_interpreter("quant_test.ptl")

Output:

['aten::dequantize.self', 'aten::len.t', 'aten::ne.int', 'aten::quantize_per_tensor', 'aten::size', 'prim::RaiseException', 'quantized::conv_transpose2d']

Then, if I run this model on the device:

at::Tensor rand_input = torch::randn({1, 3, 256, 256}).metal();
auto outputTensor = _impl.forward({ rand_input }).toTensor().cpu();

I get this output:

2022-02-03 10:47:15.805418+0300 Proj-iOS[25878:25338470] Metal API Validation Enabled
2022-02-03 10:47:15.859409+0300 Proj-iOS[25878:25338470] Could not run 'aten::quantize_per_tensor' with arguments from the 'Metal' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::quantize_per_tensor' is only available for these backends: [CPU, BackendSelect, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradLazy, AutogradXPU, AutogradMLC, AutogradHPU, Functionalize].

CPU: registered at /Users/distiller/project/build_ios/aten/src/ATen/RegisterCPU.cpp:20943 [kernel]
BackendSelect: fallthrough registered at /Users/distiller/project/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
ADInplaceOrView: fallthrough registered at /Users/distiller/project/aten/src/ATen/core/VariableFallbackKernel.cpp:64 [backend fallback]
AutogradOther: fallthrough registered at /Users/distiller/project/aten/src/ATen/core/VariableFallbackKernel.cpp:35 [backend fallback]
AutogradCPU: fallthrough registered at /Users/distiller/project/aten/src/ATen/core/VariableFallbackKernel.cpp:39 [backend fallback]
AutogradCUDA: fallthrough registered at /Users/distiller/project/aten/src/ATen/core/VariableFallbackKernel.cpp:47 [backend fallback]
AutogradXLA: fallthrough registered at /Users/distiller/project/aten/src/ATen/core/VariableFallbackKernel.cpp:51 [backend fallback]
AutogradLazy: fallthrough registered at /Users/distiller/project/aten/src/ATen/core/VariableFallbackKernel.cpp:55 [backend fallback]
AutogradXPU: fallthrough registered at /Users/distiller/project/aten/src/ATen/core/VariableFallbackKernel.cpp:43 [backend fallback]
AutogradMLC: fallthrough registered at /Users/distiller/project/aten/src/ATen/core/VariableFallbackKernel.cpp:59 [backend fallback]
AutogradHPU: fallthrough registered at /Users/distiller/project/aten/src/ATen/core/VariableFallbackKernel.cpp:68 [backend fallback]
Functionalize: registered at /Users/distiller/project/aten/src/ATen/FunctionalizeFallbackKernel.cpp:52 [backend fallback]

  
  Debug info for handle(s): debug_handles:{-1}, was not found.
  
Exception raised from reportError at /Users/distiller/project/aten/src/ATen/core/dispatch/OperatorEntry.cpp:434 (most recent call first):
frame #0: _ZNK3c104impl13OperatorEntry11reportErrorENS_11DispatchKeyE + 464 (0x10328d438 in Proj-iOS)
frame #1: _ZNK3c1010Dispatcher9callBoxedERKNS_14OperatorHandleEPNSt3__16vectorINS_6IValueENS4_9allocatorIS6_EEEE + 244 (0x10370bf74 in Proj-iOS)
frame #2: _ZN5torch3jit6mobile16InterpreterState3runERNSt3__16vectorIN3c106IValueENS3_9allocatorIS6_EEEE + 4580 (0x103717464 in Proj-iOS)
frame #3: _ZN5torch3jit6mobile8Function3runERNSt3__16vectorIN3c106IValueENS3_9allocatorIS6_EEEE + 108 (0x10370a3ac in Proj-iOS)
frame #4: _ZNK5torch3jit6mobile6Method3runERNSt3__16vectorIN3c106IValueENS3_9allocatorIS6_EEEE + 560 (0x103719d08 in Proj-iOS)
frame #5: _ZNK5torch3jit6mobile6MethodclENSt3__16vectorIN3c106IValueENS3_9allocatorIS6_EEEE + 24 (0x10371a8b4 in Proj-iOS)
frame #6: _ZN5torch3jit6mobile6Module7forwardENSt3__16vectorIN3c106IValueENS3_9allocatorIS6_EEEE + 148 (0x1038b8ef8 in Proj-iOS)
frame #7: -[InferenceModule + + (0x1038b840c in Proj-iOS)
frame #8: -[ViewController + + (0x103a350e4 in Proj-iOS)
frame #9: -[CvVideoCamera + + (0x103fddaf4 in Proj-iOS)
frame #10: CF54A5DB-6EE5-3E73-B7E1-D2D0AB7598C8 + 147436 (0x1c7081fec in AVFCapture)
frame #11: CF54A5DB-6EE5-3E73-B7E1-D2D0AB7598C8 + 146764 (0x1c7081d4c in AVFCapture)
frame #12: 2F509455-C380-3E7C-AFB5-CB0461F08A60 + 143420 (0x1c713f03c in CMCapture)
frame #13: 2F509455-C380-3E7C-AFB5-CB0461F08A60 + 2651332 (0x1c73a34c4 in CMCapture)
frame #14: 03AD11F9-67AE-3219-ACA3-DF0A9AF629D4 + 397976 (0x1ae03f298 in libdispatch.dylib)
frame #15: 03AD11F9-67AE-3219-ACA3-DF0A9AF629D4 + 236144 (0x1ae017a70 in libdispatch.dylib)
frame #16: 03AD11F9-67AE-3219-ACA3-DF0A9AF629D4 + 303900 (0x1ae02831c in libdispatch.dylib)
frame #17: 03AD11F9-67AE-3219-ACA3-DF0A9AF629D4 + 250384 (0x1ae01b210 in libdispatch.dylib)
frame #18: 03AD11F9-67AE-3219-ACA3-DF0A9AF629D4 + 253484 (0x1ae01be2c in libdispatch.dylib)
frame #19: 03AD11F9-67AE-3219-ACA3-DF0A9AF629D4 + 292460 (0x1ae02566c in libdispatch.dylib)
frame #20: _pthread_wqthread + 272 (0x1f68b55bc in libsystem_pthread.dylib)
frame #21: start_wqthread + 8 (0x1f68b886c in libsystem_pthread.dylib)

Does this mean that I can't run a quantized model on Metal at the moment? Or is there some other way to do it?
Also, as far as I understand, there is no way to run ConvTranspose2d on Metal at all: when I convert the model without quantization, I get the following error:

convolution_overrideable not implemented. You are likely triggering this with tensor backend other than CPU/CUDA/MKLDNN, if this is intended, please use TORCH_LIBRARY_IMPL to override this function
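For reference, the non-quantized export that leads to this looks roughly like the following (a minimal sketch; the class name TestFP32 and the output file name are just illustrative):

import torch
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile

# Same model without the quantization stubs
class TestFP32(nn.Module):
    def __init__(self):
        super().__init__()
        self.convt = nn.ConvTranspose2d(3, 3, 3)

    def forward(self, x):
        return self.convt(x)

with torch.no_grad():
    model = TestFP32().eval()
    script_module = torch.jit.script(model)
    # Rewrite the graph for the Metal (GPU) backend and save for the lite interpreter
    optimized_model = optimize_for_mobile(script_module, backend='metal')
    optimized_model._save_for_lite_interpreter("fp32_test.ptl")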
Thank you very much for any help!

Versions

Collecting environment information...
PyTorch version: 1.11.0.dev20220114
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 12.1 (x86_64)
GCC version: Could not collect
Clang version: 13.0.0 (clang-1300.0.29.30)
CMake version: version 3.20.5
Libc version: N/A

Python version: 3.9.9 | packaged by conda-forge | (main, Dec 20 2021, 02:41:37) [Clang 11.1.0 ] (64-bit runtime)
Python platform: macOS-12.1-x86_64-i386-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.21.2
[pip3] torch==1.11.0.dev20220114
[pip3] torchvision==0.12.0.dev20220119
[conda] blas 1.0 mkl
[conda] cudatoolkit 9.0 h41a26b3_0
[conda] libblas 3.9.0 12_osx64_mkl conda-forge
[conda] libcblas 3.9.0 12_osx64_mkl conda-forge
[conda] liblapack 3.9.0 12_osx64_mkl conda-forge
[conda] liblapacke 3.9.0 12_osx64_mkl conda-forge
[conda] mkl 2021.4.0 hecd8cb5_637
[conda] mkl-service 2.4.0 py39h9ed2024_0
[conda] mkl_fft 1.3.1 py39h4ab4a9b_0
[conda] mkl_random 1.2.2 py39hb2f4e1b_0
[conda] numpy 1.21.2 py39h4b4dc7a_0
[conda] numpy-base 1.21.2 py39he0bd621_0
[conda] pytorch 1.11.0.dev20220114 py3.9_0 pytorch-nightly
[conda] torchvision 0.12.0.dev20220119 py39_cpu pytorch-nightly

LibTorch is installed via CocoaPods (the demo nets run correctly):
pod 'LibTorch-Lite-Nightly'

Can you help here, @xta0?

Hello! Is there any news about this issue?

Yes, you cannot run quantized models on Metal.
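If CPU execution is an acceptable fallback, the quantized model can be exported without the Metal rewrite and run through the lite interpreter on the CPU with the qnnpack engine. A minimal sketch, continuing from your export script above (the file name is just for illustration):

# Export the same quantized script module for the CPU lite interpreter (no Metal rewrite)
optimized_cpu = optimize_for_mobile(script_module)
optimized_cpu._save_for_lite_interpreter("quant_test_cpu.ptl")

On the device the input tensor then stays on the CPU, i.e. no .metal() call before forward.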

What about the fp32 model? What error did you get? Can you paste the full repro for that?