UnsupportedOperatorException in fake tensor on the fake quantize operator

Hello.

I’m trying to fake-quantize my module and convert it to my backend binary, which simulates my customized quantization spec (e.g. int4). However, converting a module that uses the fake_quantize_per_tensor_affine API raised an error.

import torch
from torch.export import export

from executorch.exir import to_edge

class FakeQuantizePerTensorAffine(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        fq_x = torch.fake_quantize_per_tensor_affine(x, 0.1, 0, 0, 255)
        return (fq_x,)

    def get_example_inputs(self):
        return (torch.randn(4),)

model = FakeQuantizePerTensorAffine()
example_inputs = model.get_example_inputs()
exported = export(model, example_inputs)
# ExecuTorch lowering
module_edge = to_edge(exported)

Running this raises the following error:
Traceback (most recent call last):
  File "/home/seongwoo/my-exir/.venv/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1455, in _dispatch_impl
    r = func(*args, **kwargs)
  File "/home/seongwoo/my-exir/.venv/lib/python3.10/site-packages/torch/_ops.py", line 600, in __call__
    return self_._op(*args, **kwargs)
NotImplementedError: aten::fake_quantize_per_tensor_affine_cachemask: attempted to run this operator with Meta tensors, but there was no abstract impl or Meta kernel registered. You may have run into this message while using an operator with PT2 compilation APIs (torch.compile/torch.export); in order to use this operator with those APIs you'll need to add an abstract impl.Please see the following doc for next steps: https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/seongwoo/my-exir/test/pt2_to_my_test/test_pt2_to_my.py", line 57, in infer
    my_exir.pt2_to_my.convert(str(pt2_model), str(my_model))
  File "/home/seongwoo/my-exir/my_exir/pt2_to_my.py", line 16, in convert
    my_program = convert_exported_module_to_my(exported_program)
  File "/home/seongwoo/my-exir/my_exir/utils/utils.py", line 14, in convert_exported_module_to_my
    module_edge = to_edge(model)
  File "/home/seongwoo/my-exir/.venv/lib/python3.10/site-packages/executorch/exir/program/_program.py", line 631, in to_edge
    program = program.run_decompositions(_default_decomposition_table())
  File "/home/seongwoo/my-exir/.venv/lib/python3.10/site-packages/torch/export/exported_program.py", line 84, in wrapper
    return fn(*args, **kwargs)
  File "/home/seongwoo/my-exir/.venv/lib/python3.10/site-packages/torch/export/exported_program.py", line 480, in run_decompositions
    gm, graph_signature = aot_export_module(
  File "/home/seongwoo/my-exir/.venv/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1047, in aot_export_module
    fx_g, metadata, in_spec, out_spec = _aot_export_function(
  File "/home/seongwoo/my-exir/.venv/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1237, in _aot_export_function
    fx_g, meta = create_aot_dispatcher_function(
  File "/home/seongwoo/my-exir/.venv/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 265, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/seongwoo/my-exir/.venv/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 533, in create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
  File "/home/seongwoo/my-exir/.venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 150, in inner
    flat_f_outs = f(*flat_f_args)
  File "/home/seongwoo/my-exir/.venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 171, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "/home/seongwoo/my-exir/.venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 680, in functional_call
    out = PropagateUnbackedSymInts(mod).run(
  File "/home/seongwoo/my-exir/.venv/lib/python3.10/site-packages/torch/fx/interpreter.py", line 145, in run
    self.env[node] = self.run_node(node)
  File "/home/seongwoo/my-exir/.venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 655, in run_node
    result = super().run_node(n)
  File "/home/seongwoo/my-exir/.venv/lib/python3.10/site-packages/torch/fx/interpreter.py", line 202, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/home/seongwoo/my-exir/.venv/lib/python3.10/site-packages/torch/fx/interpreter.py", line 274, in call_function
    return target(*args, **kwargs)
  File "/home/seongwoo/my-exir/.venv/lib/python3.10/site-packages/torch/_ops.py", line 600, in __call__
    return self_._op(*args, **kwargs)
  File "/home/seongwoo/my-exir/.venv/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py", line 420, in __torch_dispatch__
    outs_unwrapped = func._op_dk(
  File "/home/seongwoo/my-exir/.venv/lib/python3.10/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/seongwoo/my-exir/.venv/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 893, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/seongwoo/my-exir/.venv/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1238, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "/home/seongwoo/my-exir/.venv/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 963, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/home/seongwoo/my-exir/.venv/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1457, in _dispatch_impl
    return maybe_run_unsafe_fallback(not_implemented_error)
  File "/home/seongwoo/my-exir/.venv/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1440, in maybe_run_unsafe_fallback
    raise UnsupportedOperatorException(func)
torch._subclasses.fake_tensor.UnsupportedOperatorException: aten.fake_quantize_per_tensor_affine_cachemask.default

While executing %fake_quantize_per_tensor_affine_cachemask_default : [num_users=2] = call_function[target=torch.ops.aten.fake_quantize_per_tensor_affine_cachemask.default](args = (%arg0_1, 0.1, 0, 0, 255), kwargs = {})
Original traceback:
  File "/home/seongwoo/my-exir/test/modules/single/op/fake_quantize_per_tensor_affine.py", line 9, in forward
    fq_x=torch.fake_quantize_per_tensor_affine(x, 0.1, 0, 0, 255)


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/seongwoo/my-exir/test/pt2_to_my_test/test_pt2_to_my.py", line 165, in test_verify
    verify(classname, instance)
  File "/home/seongwoo/my-exir/test/pt2_to_my_test/test_pt2_to_my.py", line 110, in verify
    torch_result, my_result = infer(test, cls)
  File "/home/seongwoo/my-exir/test/pt2_to_my_test/test_pt2_to_my.py", line 59, in infer
    raise RuntimeError(f"{test}: pt2-to-my failed.\n\n {err}")
RuntimeError: FakeQuantizePerTensorAffine: pt2-to-my failed.

 aten.fake_quantize_per_tensor_affine_cachemask.default

While executing %fake_quantize_per_tensor_affine_cachemask_default : [num_users=2] = call_function[target=torch.ops.aten.fake_quantize_per_tensor_affine_cachemask.default](args = (%arg0_1, 0.1, 0, 0, 255), kwargs = {})
Original traceback:
  File "/home/seongwoo/my-exir/test/modules/single/op/fake_quantize_per_tensor_affine.py", line 9, in forward
    fq_x=torch.fake_quantize_per_tensor_affine(x, 0.1, 0, 0, 255)

What I expected is behavior similar to the ONNX export. For instance, when I export the module above to ONNX, fake_quantize_per_tensor_affine is converted to QuantizeLinear and DequantizeLinear nodes. Likewise, I expected to_edge to produce a model containing the corresponding quantize- and dequantize-related Edge dialect ops.
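
For reference, this is roughly the ONNX path I’m comparing against (the output filename and opset version are just illustrative):

import torch

model = FakeQuantizePerTensorAffine()
example_inputs = model.get_example_inputs()

# fake_quantize_per_tensor_affine lowers to QuantizeLinear + DequantizeLinear here.
torch.onnx.export(model, example_inputs, "fake_quantize.onnx", opset_version=13)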

It seems that fake_quantize_per_tensor_affine is not supported. Is there any other API or a workaround for this case?
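
One workaround I’m considering is to express the fake quantize as an explicit quantize/dequantize pair using the quantized_decomposed ops (a sketch only; I haven’t verified it end to end, and the op names are simply what I see the PT2E quantization flow produce):

import torch
import torch.ao.quantization.fx._decomposed  # registers torch.ops.quantized_decomposed.*

class DecomposedFakeQuantize(torch.nn.Module):
    def forward(self, x):
        # Quantize then dequantize explicitly; both ops have meta kernels,
        # so export/to_edge should be able to trace through them.
        q = torch.ops.quantized_decomposed.quantize_per_tensor(
            x, 0.1, 0, 0, 255, torch.uint8
        )
        dq = torch.ops.quantized_decomposed.dequantize_per_tensor(
            q, 0.1, 0, 0, 255, torch.uint8
        )
        return (dq,)

Does something like this look like the right direction?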