nn.BatchNorm2d with metal

Hi! I’m trying to run my model on an iOS mobile device (iPhone X) with GPU support (Metal), but I get some errors, so I’m hoping for your help. Here is a small example:

import torch
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile

class TEST(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3,3, 3)
        self.bn = nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(self.conv(x))


rand_input = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    export_model = TEST().eval().cpu()
    print(export_model(rand_input).shape, export_model(rand_input).dtype)

    # Both tracing and scripting lead to the same problem described below.
    traced_script_module = torch.jit.trace(export_model, rand_input.cpu())
    #traced_script_module = torch.jit.script(export_model)
    optimized_model = optimize_for_mobile(traced_script_module, backend='metal')
    print(torch.jit.export_opnames(optimized_model))
    optimized_model._save_for_lite_interpreter("traced_test.ptl")

I get this output:

torch.Size([1, 3, 254, 254]) torch.float32
['metal_prepack::conv2d_run']
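The single metal_prepack::conv2d_run opname suggests that optimize_for_mobile folded the BatchNorm into the prepacked Metal convolution. A quick way to double-check this (a sketch, assuming the export script above has just run) is to print the optimized graph:

print(optimized_model.graph)  # should show metal_prepack::conv2d_run and no aten::batch_norm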

And the example runs correctly on the device with this code:

at::Tensor rand_input = torch::randn({1,3,256,256}).metal();
auto outputTensor = _impl.forward({ rand_input }).toTensor().cpu();

But if I change the return in the model to

return self.bn(x)

or use nn.InstanceNorm2d(3) instead of nn.BatchNorm2d(3),
I get this output:

torch.Size([1, 3, 254, 254]) torch.float32
['aten::instance_norm', 'metal_prepack::conv2d_run']
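For completeness, here is a sketch of the variant that matches the op list above (the class name is just for illustration): the normalization is no longer a BatchNorm directly after the convolution, so optimize_for_mobile leaves it as a plain aten op instead of folding it into metal_prepack::conv2d_run.

class TESTFailing(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)
        self.norm = nn.InstanceNorm2d(3)  # not foldable into the Metal conv

    def forward(self, x):
        return self.norm(self.conv(x))  # leaves aten::instance_norm in the exported ops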

and my iOS code fails with this error:

2022-02-03 10:26:48.591695+0300 Proj-iOS[25863:25332632] Metal API Validation Enabled
2022-02-03 10:26:49.017366+0300 Proj-iOS[25863:25332632] Could not run 'aten::native_batch_norm' with arguments from the 'Metal' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::native_batch_norm' is only available for these backends: [CPU, BackendSelect, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradLazy, AutogradXPU, AutogradMLC, AutogradHPU, Functionalize].

CPU: registered at /Users/distiller/project/build_ios/aten/src/ATen/RegisterCPU.cpp:20943 [kernel]
BackendSelect: fallthrough registered at /Users/distiller/project/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
ADInplaceOrView: fallthrough registered at /Users/distiller/project/aten/src/ATen/core/VariableFallbackKernel.cpp:64 [backend fallback]
AutogradOther: fallthrough registered at /Users/distiller/project/aten/src/ATen/core/VariableFallbackKernel.cpp:35 [backend fallback]
AutogradCPU: fallthrough registered at /Users/distiller/project/aten/src/ATen/core/VariableFallbackKernel.cpp:39 [backend fallback]
AutogradCUDA: fallthrough registered at /Users/distiller/project/aten/src/ATen/core/VariableFallbackKernel.cpp:47 [backend fallback]
AutogradXLA: fallthrough registered at /Users/distiller/project/aten/src/ATen/core/VariableFallbackKernel.cpp:51 [backend fallback]
AutogradLazy: fallthrough registered at /Users/distiller/project/aten/src/ATen/core/VariableFallbackKernel.cpp:55 [backend fallback]
AutogradXPU: fallthrough registered at /Users/distiller/project/aten/src/ATen/core/VariableFallbackKernel.cpp:43 [backend fallback]
AutogradMLC: fallthrough registered at /Users/distiller/project/aten/src/ATen/core/VariableFallbackKernel.cpp:59 [backend fallback]
AutogradHPU: fallthrough registered at /Users/distiller/project/aten/src/ATen/core/VariableFallbackKernel.cpp:68 [backend fallback]
Functionalize: registered at /Users/distiller/project/aten/src/ATen/FunctionalizeFallbackKernel.cpp:52 [backend fallback]

  
  Debug info for handle(s): debug_handles:{-1}, was not found.
  
Exception raised from reportError at /Users/distiller/project/aten/src/ATen/core/dispatch/OperatorEntry.cpp:434 (most recent call first):
frame #0: _ZNK3c104impl13OperatorEntry11reportErrorENS_11DispatchKeyE + 464 (0x103285438 in Proj-iOS)
frame #1: _ZNK3c1010Dispatcher4callINSt3__15tupleIJN2at6TensorES5_S5_EEEJRKS5_RKNS_8optionalIS5_EESC_SC_SC_bddEEET_RKNS_19TypedOperatorHandleIFSD_DpT0_EEESG_ + 296 (0x1029492c4 in Proj-iOS)
frame #2: _ZN2at6native22_batch_norm_impl_indexERKNS_6TensorERKN3c108optionalIS1_EES8_S8_S8_bddb + 1568 (0x103488980 in Proj-iOS)
frame #3: _ZNK3c1010Dispatcher4callINSt3__15tupleIJN2at6TensorES5_S5_S5_xEEEJRKS5_RKNS_8optionalIS5_EESC_SC_SC_bddbEEET_RKNS_19TypedOperatorHandleIFSD_DpT0_EEESG_ + 268 (0x10288e1ac in Proj-iOS)
frame #4: _ZN2at4_ops22_batch_norm_impl_index4callERKNS_6TensorERKN3c108optionalIS2_EES9_S9_S9_bddb + 144 (0x10283ef60 in Proj-iOS)
frame #5: _ZN2at6native10batch_normERKNS_6TensorERKN3c108optionalIS1_EES8_S8_S8_bddb + 548 (0x10348a0ec in Proj-iOS)
frame #6: _ZNK3c1010Dispatcher4callIN2at6TensorEJRKS3_RKNS_8optionalIS3_EES9_S9_S9_bddbEEET_RKNS_19TypedOperatorHandleIFSA_DpT0_EEESD_ + 268 (0x1029cfe80 in Proj-iOS)
frame #7: _ZN2at4_ops10batch_norm4callERKNS_6TensorERKN3c108optionalIS2_EES9_S9_S9_bddb + 144 (0x102983cac in Proj-iOS)
frame #8: _ZN2at6native13instance_normERKNS_6TensorERKN3c108optionalIS1_EES8_S8_S8_bddb + 976 (0x10348a5f8 in Proj-iOS)
frame #9: _ZN3c104impl34call_functor_with_args_from_stack_INS0_6detail31WrapFunctionIntoRuntimeFunctor_IPFN2at6TensorERKS5_RKNS_8optionalIS5_EESB_SB_SB_bddbES5_NS_4guts8typelist8typelistIJS7_SB_SB_SB_SB_bddbEEEEELb0EJLm0ELm1ELm2ELm3ELm4ELm5ELm6ELm7ELm8EEJS7_SB_SB_SB_SB_bddbEEENSt3__15decayINSE_21infer_function_traitsIT_E4type11return_typeEE4typeEPNS_14OperatorKernelENS_14DispatchKeySetEPNSJ_6vectorINS_6IValueENSJ_9allocatorISV_EEEENSJ_16integer_sequenceImJXspT1_EEEEPNSG_IJDpT2_EEE + 220 (0x102c2b9f4 in Proj-iOS)
frame #10: _ZN3c104impl31make_boxed_from_unboxed_functorINS0_6detail31WrapFunctionIntoRuntimeFunctor_IPFN2at6TensorERKS5_RKNS_8optionalIS5_EESB_SB_SB_bddbES5_NS_4guts8typelist8typelistIJS7_SB_SB_SB_SB_bddbEEEEELb0EE4callEPNS_14OperatorKernelERKNS_14OperatorHandleENS_14DispatchKeySetEPNSt3__16vectorINS_6IValueENSQ_9allocatorISS_EEEE + 40 (0x102c2b8a4 in Proj-iOS)
frame #11: _ZNK3c1010Dispatcher9callBoxedERKNS_14OperatorHandleEPNSt3__16vectorINS_6IValueENS4_9allocatorIS6_EEEE + 164 (0x103703f24 in Proj-iOS)
frame #12: _ZN5torch3jit6mobile16InterpreterState3runERNSt3__16vectorIN3c106IValueENS3_9allocatorIS6_EEEE + 4580 (0x10370f464 in Proj-iOS)
frame #13: _ZN5torch3jit6mobile8Function3runERNSt3__16vectorIN3c106IValueENS3_9allocatorIS6_EEEE + 108 (0x1037023ac in Proj-iOS)
frame #14: _ZNK5torch3jit6mobile6Method3runERNSt3__16vectorIN3c106IValueENS3_9allocatorIS6_EEEE + 560 (0x103711d08 in Proj-iOS)
frame #15: _ZNK5torch3jit6mobile6MethodclENSt3__16vectorIN3c106IValueENS3_9allocatorIS6_EEEE + 24 (0x1037128b4 in Proj-iOS)
frame #16: _ZN5torch3jit6mobile6Module7forwardENSt3__16vectorIN3c106IValueENS3_9allocatorIS6_EEEE + 148 (0x1038b0ef8 in Proj-iOS)
frame #17: -[InferenceModule + + (0x1038b040c in Proj-iOS)
frame #18: -[ViewController + + (0x103a2d0e4 in Proj-iOS)
frame #19: -[CvVideoCamera + + (0x103fd5af4 in Proj-iOS)
frame #20: CF54A5DB-6EE5-3E73-B7E1-D2D0AB7598C8 + 147436 (0x1c7081fec in AVFCapture)
frame #21: CF54A5DB-6EE5-3E73-B7E1-D2D0AB7598C8 + 146764 (0x1c7081d4c in AVFCapture)
frame #22: 2F509455-C380-3E7C-AFB5-CB0461F08A60 + 143420 (0x1c713f03c in CMCapture)
frame #23: 2F509455-C380-3E7C-AFB5-CB0461F08A60 + 2651332 (0x1c73a34c4 in CMCapture)
frame #24: 03AD11F9-67AE-3219-ACA3-DF0A9AF629D4 + 397976 (0x1ae03f298 in libdispatch.dylib)
frame #25: 03AD11F9-67AE-3219-ACA3-DF0A9AF629D4 + 236144 (0x1ae017a70 in libdispatch.dylib)
frame #26: 03AD11F9-67AE-3219-ACA3-DF0A9AF629D4 + 303900 (0x1ae02831c in libdispatch.dylib)
frame #27: 03AD11F9-67AE-3219-ACA3-DF0A9AF629D4 + 250384 (0x1ae01b210 in libdispatch.dylib)
frame #28: 03AD11F9-67AE-3219-ACA3-DF0A9AF629D4 + 253484 (0x1ae01be2c in libdispatch.dylib)
frame #29: 03AD11F9-67AE-3219-ACA3-DF0A9AF629D4 + 292460 (0x1ae02566c in libdispatch.dylib)
frame #30: _pthread_wqthread + 272 (0x1f68b55bc in libsystem_pthread.dylib)
frame #31: start_wqthread + 8 (0x1f68b886c in libsystem_pthread.dylib)

I understand that the problem is most likely that when BatchNorm is called right after Conv, optimize_for_mobile fuses these layers, so that case works. Does that mean I can’t use normalization in any way except directly after a convolution?
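One workaround I’m considering (an untested sketch, not something from the docs): in eval mode BatchNorm2d is just a fixed per-channel affine transform, so it can be rewritten with plain elementwise ops. Whether the resulting ops (sub, div, mul, add) all have Metal kernels would still need to be checked with torch.jit.export_opnames on the optimized model.

class ManualBatchNorm2d(nn.Module):
    """Frozen eval-mode BatchNorm expressed with elementwise ops (hypothetical workaround)."""
    def __init__(self, bn: nn.BatchNorm2d):
        super().__init__()
        # Freeze the statistics and affine parameters of an already-trained BatchNorm.
        self.register_buffer("mean", bn.running_mean.detach().view(1, -1, 1, 1))
        self.register_buffer("std", (bn.running_var.detach() + bn.eps).sqrt().view(1, -1, 1, 1))
        self.register_buffer("weight", bn.weight.detach().view(1, -1, 1, 1))
        self.register_buffer("bias", bn.bias.detach().view(1, -1, 1, 1))

    def forward(self, x):
        # (x - mean) / sqrt(var + eps) * gamma + beta, channel-wise
        return (x - self.mean) / self.std * self.weight + self.bias

If that traces and exports cleanly, it could replace self.bn wherever the normalization does not directly follow a convolution.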
Thank you very much for any help!

Versions

Collecting environment information...
PyTorch version: 1.11.0.dev20220114
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 12.1 (x86_64)
GCC version: Could not collect
Clang version: 13.0.0 (clang-1300.0.29.30)
CMake version: version 3.20.5
Libc version: N/A

Python version: 3.9.9 | packaged by conda-forge | (main, Dec 20 2021, 02:41:37) [Clang 11.1.0 ] (64-bit runtime)
Python platform: macOS-12.1-x86_64-i386-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.21.2
[pip3] torch==1.11.0.dev20220114
[pip3] torchvision==0.12.0.dev20220119
[conda] blas 1.0 mkl
[conda] cudatoolkit 9.0 h41a26b3_0
[conda] libblas 3.9.0 12_osx64_mkl conda-forge
[conda] libcblas 3.9.0 12_osx64_mkl conda-forge
[conda] liblapack 3.9.0 12_osx64_mkl conda-forge
[conda] liblapacke 3.9.0 12_osx64_mkl conda-forge
[conda] mkl 2021.4.0 hecd8cb5_637
[conda] mkl-service 2.4.0 py39h9ed2024_0
[conda] mkl_fft 1.3.1 py39h4ab4a9b_0
[conda] mkl_random 1.2.2 py39hb2f4e1b_0
[conda] numpy 1.21.2 py39h4b4dc7a_0
[conda] numpy-base 1.21.2 py39he0bd621_0
[conda] pytorch 1.11.0.dev20220114 py3.9_0 pytorch-nightly
[conda] torchvision 0.12.0.dev20220119 py39_cpu pytorch-nightly

LibTorch is installed with (the demo nets work correctly):
pod 'LibTorch-Lite-Nightly'

@xta0 can you help here?