When creating a `torch.Tensor` in FP8 precision with `torch.empty` inside a `torch.compile`d function, I got the following error:
Traceback (most recent call last):
File "/workspace/temp.py", line 13, in <module>
test_compile_torch_empty_w_2D_huge_shape() #works
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
result = self._inner_convert(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
return _compile(
^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_utils_internal.py", line 84, in wrapper_function
return StrobelightCompileTimeProfiler.profile_compile_time(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
transformations(instructions, code_options)
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 582, in transform
tracer.run()
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2451, in run
super().run()
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
while self.step():
^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
self.dispatch_table[inst.opcode](self, inst)
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2642, in RETURN_VALUE
self._return(inst)
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2627, in _return
self.output.compile_subgraph(
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py", line 1124, in compile_subgraph
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
File "/usr/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py", line 1319, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py", line 1410, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py", line 1391, in call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/__init__.py", line 1951, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 1505, in compile_fx
return aot_autograd(
^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/backends/common.py", line 69, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/aot_autograd.py", line 954, in aot_module_simplified
compiled_fn, _ = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/aot_autograd.py", line 687, in create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 169, in aot_dispatch_base
compiled_fw = compiler(fw_module, updated_flat_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 1352, in fw_compiler_base
_recursive_joint_graph_passes(model)
File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 256, in _recursive_joint_graph_passes
joint_graph_passes(gm)
File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/fx_passes/joint_graph.py", line 322, in joint_graph_passes
constant_fold_uniform_value(graph)
File "/usr/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/fx_passes/joint_graph.py", line 228, in constant_fold_uniform_value
cf.run()
File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/constant_folding.py", line 200, in run
return super().run(initial_env=env)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/fx/interpreter.py", line 146, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/constant_folding.py", line 174, in run_node
self.add_node_replacement(node, out)
File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/fx_passes/joint_graph.py", line 213, in add_node_replacement
self.node_replacements[node] = tensor.flatten()[0].item()
^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: "_local_scalar_dense_cuda" not implemented for 'Float8_e4m3fn'
While executing %empty : [num_users=1] = call_function[target=torch.ops.aten.empty.memory_format](args = ([512, 1048576],), kwargs = {dtype: torch.float8_e4m3fn, device: cuda, pin_memory: False})
Original traceback:
File "/workspace/temp.py", line 9, in test_compile_torch_empty_w_2D_huge_shape
return torch.empty(512, 1024*1024, device="cuda", dtype=torch.float8_e4m3fn), torch.empty(512, 1024*1024, device="cuda", dtype=torch.float8_e5m2)
However, the exception does not occur in every case. When I create an FP8 empty tensor with a small shape, Inductor compiles it without error, as the following script shows:
import torch
@torch.compile(backend="inductor")
def test_compile_torch_empty_w_2D_small_shape():
    """Allocate one uninitialized 512 x 512 CUDA tensor per FP8 format.

    Compiled with the Inductor backend. Returns the pair
    ``(float8_e4m3fn tensor, float8_e5m2 tensor)``; the contents are
    whatever ``torch.empty`` leaves in memory.

    Control case of the repro: reported to compile and run fine on
    PyTorch 2.4.1.
    """
    rows, cols = 512, 512
    e4m3 = torch.empty(rows, cols, device="cuda", dtype=torch.float8_e4m3fn)
    e5m2 = torch.empty(rows, cols, device="cuda", dtype=torch.float8_e5m2)
    return e4m3, e5m2
@torch.compile(backend="inductor")
def test_compile_torch_empty_w_2D_huge_shape():
    """Allocate one uninitialized 512 x (1024*1024) CUDA tensor per FP8 format.

    Identical to the small-shape case except for the column count. Returns
    the pair ``(float8_e4m3fn tensor, float8_e5m2 tensor)``.

    On PyTorch 2.4.1, compiling this raises ``BackendCompilerFailed``
    wrapping ``RuntimeError: "_local_scalar_dense_cuda" not implemented for
    'Float8_e4m3fn'``. The error comes from Inductor's
    ``constant_fold_uniform_value`` joint-graph pass, which calls
    ``tensor.flatten()[0].item()`` on the ``empty`` result.

    NOTE(review): presumably the small shape escapes this pass because its
    uninitialized contents are not all equal, so it is never treated as a
    uniform value. Unconfirmed.
    """
    rows, cols = 512, 1024 * 1024
    e4m3 = torch.empty(rows, cols, device="cuda", dtype=torch.float8_e4m3fn)
    e5m2 = torch.empty(rows, cols, device="cuda", dtype=torch.float8_e5m2)
    return e4m3, e5m2
# Run the small-shape control case first, then the huge-shape case, in the
# same order as in the report.
test_compile_torch_empty_w_2D_small_shape() # works
print("testing compile torch empty with 2D small shape works")
# Expected to raise torch._dynamo.exc.BackendCompilerFailed on PyTorch 2.4.1
# (reported not to raise on 2.5.1). The print below is therefore only reached
# on unaffected versions.
# NOTE(review): possible 2.4.1 workarounds, unverified:
#   - disable Inductor's joint-graph constant folding before the first call,
#     e.g. `torch._inductor.config.joint_graph_constant_folding = False`
#     (confirm this flag exists in 2.4.1);
#   - or allocate the FP8 tensors outside the compiled function.
test_compile_torch_empty_w_2D_huge_shape() # fails
print("testing compile torch empty with 2D huge shape works")
This is confusing because `torch.empty` performs no fill operation, unlike `torch.ones`.
------------------------------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
------------------------------------ ------------ ------------ ------------ ------------ ------------ ------------
cudaGetDeviceCount 0.56% 1.900ms 0.56% 1.900ms 950.107us 2
cudaGetDeviceProperties_v2 5.40% 18.277ms 5.40% 18.277ms 2.285ms 8
aten::empty 0.08% 269.176us 94.03% 317.997ms 158.998ms 2
cudaDeviceGetStreamPriorityRange 93.89% 317.520ms 93.89% 317.520ms 317.520ms 1
cudaStreamIsCapturing 0.00% 12.234us 0.00% 12.234us 12.234us 1
cudaMalloc 0.06% 195.565us 0.06% 195.565us 195.565us 1
cudaDeviceSynchronize 0.00% 10.947us 0.00% 10.947us 10.947us 1
------------------------------------ ------------ ------------ ------------ ------------ ------------ ------------
The exception occurs with PyTorch 2.4.1, but not with PyTorch 2.5.1.
How can I fix, or work around, this problem in PyTorch 2.4.1?