FP8 `torch.empty` fails under the `inductor` backend in PyTorch 2.4.1

When I create an FP8 tensor with `torch.empty` inside a function compiled with `torch.compile`, I get the following error:

Traceback (most recent call last):
  File "/workspace/temp.py", line 13, in <module>
    test_compile_torch_empty_w_2D_huge_shape() #works
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
    return _compile(
           ^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_utils_internal.py", line 84, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 582, in transform
    tracer.run()
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2451, in run
    super().run()
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
          ^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2642, in RETURN_VALUE
    self._return(inst)
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2627, in _return
    self.output.compile_subgraph(
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py", line 1124, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/usr/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py", line 1319, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py", line 1410, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py", line 1391, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/__init__.py", line 1951, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 1505, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/backends/common.py", line 69, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/aot_autograd.py", line 954, in aot_module_simplified
    compiled_fn, _ = create_aot_dispatcher_function(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/aot_autograd.py", line 687, in create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
                               ^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 169, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 1352, in fw_compiler_base
    _recursive_joint_graph_passes(model)
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 256, in _recursive_joint_graph_passes
    joint_graph_passes(gm)
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/fx_passes/joint_graph.py", line 322, in joint_graph_passes
    constant_fold_uniform_value(graph)
  File "/usr/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/fx_passes/joint_graph.py", line 228, in constant_fold_uniform_value
    cf.run()
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/constant_folding.py", line 200, in run
    return super().run(initial_env=env)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/constant_folding.py", line 174, in run_node
    self.add_node_replacement(node, out)
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/fx_passes/joint_graph.py", line 213, in add_node_replacement
    self.node_replacements[node] = tensor.flatten()[0].item()
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: "_local_scalar_dense_cuda" not implemented for 'Float8_e4m3fn'

While executing %empty : [num_users=1] = call_function[target=torch.ops.aten.empty.memory_format](args = ([512, 1048576],), kwargs = {dtype: torch.float8_e4m3fn, device: cuda, pin_memory: False})
Original traceback:
  File "/workspace/temp.py", line 9, in test_compile_torch_empty_w_2D_huge_shape
    return torch.empty(512, 1024*1024, device="cuda", dtype=torch.float8_e4m3fn), torch.empty(512, 1024*1024, device="cuda", dtype=torch.float8_e5m2)

However, the exception does not occur in every case: when the FP8 empty tensor has a small shape, Inductor compiles it without error. Here is a minimal reproduction:

```python
import torch

@torch.compile(backend="inductor")
def test_compile_torch_empty_w_2D_small_shape():
    return torch.empty(512, 512, device="cuda", dtype=torch.float8_e4m3fn), torch.empty(512, 512, device="cuda", dtype=torch.float8_e5m2)

@torch.compile(backend="inductor")
def test_compile_torch_empty_w_2D_huge_shape():
    return torch.empty(512, 1024*1024, device="cuda", dtype=torch.float8_e4m3fn), torch.empty(512, 1024*1024, device="cuda", dtype=torch.float8_e5m2)

test_compile_torch_empty_w_2D_small_shape() # works
print("testing compile torch empty with 2D small shape works")
test_compile_torch_empty_w_2D_huge_shape() # fails
print("testing compile torch empty with 2D huge shape works")
```

This is confusing, because `torch.empty`, unlike `torch.ones`, performs no fill operation. A profile of the two `torch.empty` calls shows only an allocation (`cudaMalloc`) and no fill kernel:

```
------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                  cudaGetDeviceCount         0.56%       1.900ms         0.56%       1.900ms     950.107us             2  
          cudaGetDeviceProperties_v2         5.40%      18.277ms         5.40%      18.277ms       2.285ms             8  
                         aten::empty         0.08%     269.176us        94.03%     317.997ms     158.998ms             2  
    cudaDeviceGetStreamPriorityRange        93.89%     317.520ms        93.89%     317.520ms     317.520ms             1  
               cudaStreamIsCapturing         0.00%      12.234us         0.00%      12.234us      12.234us             1  
                          cudaMalloc         0.06%     195.565us         0.06%     195.565us     195.565us             1  
               cudaDeviceSynchronize         0.00%      10.947us         0.00%      10.947us      10.947us             1  
------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
```

The exception occurs with PyTorch 2.4.1 but not with PyTorch 2.5.1.

How can I fix this problem in PyTorch 2.4.1?

Great to hear! I would recommend using the latest stable release (2.7.0) or a nightly binary, as the issue seems to be fixed already.

Thanks for your great teamwork!

Unfortunately, I can't update PyTorch at the moment because of compatibility constraints in our current system configuration. Could you tell me how I can fix this problem on PyTorch 2.4?