I was trying the new torch.compile function and ran into an error when compiling code that uses autocast. I'm not sure whether this behaviour is intended; if it is, the reason isn't clear.
The following snippet should reproduce the error.
import torch

def autocast():
    with torch.autocast('cuda'):
        return

autocast()  # runs just fine

@torch.compile()
def opt_autocast():
    with torch.autocast('cuda'):
        return

opt_autocast()  # throws AssertionError
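Judging by the assertion that fails (assert "device_type" in kwargs in torch/_dynamo/variables/misc.py; see the traceback below), Dynamo appears to handle only the case where the device type is passed as a keyword argument. If that's right, the following keyword-argument variant should avoid the error, though I'm not sure whether this is the intended usage:

@torch.compile()
def opt_autocast_kw():
    # Workaround sketch: pass device_type as a keyword so that the
    # failing `assert "device_type" in kwargs` in
    # AutocastModeVariable.create succeeds.
    with torch.autocast(device_type='cuda'):
        return

opt_autocast_kw()  # expected to compile without the AssertionError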
The traceback, with torch._dynamo.config.verbose=True:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Cell In[23], line 14
11 with torch.autocast('cuda'):
12 return
---> 14 opt_autocast()
File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:211, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
209 dynamic_ctx.__enter__()
210 try:
--> 211 return fn(*args, **kwargs)
212 finally:
213 set_eval_frame(prior)
File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:332, in catch_errors_wrapper.<locals>.catch_errors(frame, cache_size)
329 return hijacked_callback(frame, cache_size, hooks)
331 with compile_lock:
--> 332 return callback(frame, cache_size, hooks)
File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:480, in convert_frame.<locals>._convert_frame(frame, cache_size, hooks)
478 counters["frames"]["total"] += 1
479 try:
--> 480 result = inner_convert(frame, cache_size, hooks)
481 counters["frames"]["ok"] += 1
482 return result
File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:103, in wrap_convert_context.<locals>._fn(*args, **kwargs)
101 torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
102 try:
--> 103 return fn(*args, **kwargs)
104 finally:
105 torch._C._set_grad_enabled(prior_grad_mode)
File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/utils.py:90, in dynamo_timed.<locals>.time_wrapper(*args, **kwargs)
88 compilation_metrics[key] = []
89 t0 = time.time()
---> 90 r = func(*args, **kwargs)
91 latency = time.time() - t0
92 # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")
File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:339, in convert_frame_assert.<locals>._convert_frame_assert(frame, cache_size, hooks)
336 global initial_grad_state
337 initial_grad_state = torch.is_grad_enabled()
--> 339 return _compile(
340 frame.f_code,
341 frame.f_globals,
342 frame.f_locals,
343 frame.f_builtins,
344 compiler_fn,
345 one_graph,
346 export,
347 hooks,
348 frame,
349 )
File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:400, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, hooks, frame)
398 for attempt in itertools.count():
399 try:
--> 400 out_code = transform_code_object(code, transform)
401 orig_code_map[out_code] = code
402 break
File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py:341, in transform_code_object(code, transformations, safe)
338 instructions = cleaned_instructions(code, safe)
339 propagate_line_nums(instructions)
--> 341 transformations(instructions, code_options)
343 fix_vars(instructions, code_options)
345 dirty = True
File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:387, in _compile.<locals>.transform(instructions, code_options)
374 nonlocal output
375 tracer = InstructionTranslator(
376 instructions,
377 code,
(...)
385 mutated_closure_cell_contents,
386 )
--> 387 tracer.run()
388 output = tracer.output
389 assert output is not None
File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1687, in InstructionTranslator.run(self)
1685 def run(self):
1686 _step_logger()(logging.INFO, f"torchdynamo start tracing {self.f_code.co_name}")
-> 1687 super().run()
File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:538, in InstructionTranslatorBase.run(self)
533 try:
534 self.output.push_tx(self)
535 while (
536 self.instruction_pointer is not None
537 and not self.output.should_exit
--> 538 and self.step()
539 ):
540 pass
541 except BackendCompilerFailed:
File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:501, in InstructionTranslatorBase.step(self)
499 if not hasattr(self, inst.opname):
500 unimplemented(f"missing: {inst.opname}")
--> 501 getattr(self, inst.opname)(inst)
503 return inst.opname != "RETURN_VALUE"
504 except BackendCompilerFailed:
File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:307, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
305 reason = None
306 try:
--> 307 return inner_fn(self, inst)
308 except Unsupported as excp:
309 if self.has_backedge():
File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:966, in InstructionTranslatorBase.CALL_FUNCTION(self, inst)
964 args = self.popn(inst.argval)
965 fn = self.pop()
--> 966 self.call_function(fn, args, {})
File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:435, in InstructionTranslatorBase.call_function(self, fn, args, kwargs)
430 assert isinstance(kwargs, dict)
431 assert all(
432 isinstance(x, VariableTracker)
433 for x in itertools.chain(args, kwargs.values())
434 )
--> 435 self.push(fn.call_function(self, args, kwargs))
File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/variables/torch.py:304, in TorchVariable.call_function(self, tx, args, kwargs)
297 return TensorWithTFOverrideVariable(
298 unwrapped,
299 tensor_with_tf_override.orig_tensor_variable_source,
300 tensor_with_tf_override.subclass_torch_function__func,
301 tensor_with_tf_override.subclass_type,
302 )
303 elif self.value is torch.amp.autocast_mode.autocast:
--> 304 return AutocastModeVariable.create(target_values=args, kwargs=kwargs)
305 elif self.value in (
306 torch.profiler.profile,
307 torch.profiler.record_function,
308 torch.autograd.profiler.profile,
309 torch.autograd.profiler.record_function,
310 ):
311 log.warning("Profiler will be ignored")
File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py:311, in AutocastModeVariable.create(target_values, kwargs)
306 values = target_values
307 # device_type : str,
308 # dtype : Optional[_dtype] = None,
309 # enabled : bool = True,
310 # cache_enabled : Optional[bool] = None):cache_enabled
--> 311 assert "device_type" in kwargs
312 values.append(kwargs["device_type"])
313 del kwargs["device_type"]
AssertionError:
from user code:
File "/tmp/ipykernel_38728/2263624541.py", line 11, in opt_autocast
with torch.autocast('cuda'):
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True
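For context, AutocastModeVariable.create only looks for device_type in kwargs and never maps the positional arguments already captured in target_values back to parameter names. A minimal sketch of the argument normalization that seems to be missing (hypothetical, and ignoring that Dynamo actually traffics in VariableTracker objects rather than raw Python values):

import inspect
import torch

def normalize_autocast_args(args, kwargs):
    # Hypothetical helper: bind positional and keyword arguments against
    # torch.autocast's signature so device_type is recovered either way.
    sig = inspect.signature(torch.autocast)
    bound = sig.bind(*args, **kwargs)
    bound.apply_defaults()
    return dict(bound.arguments)

print(normalize_autocast_args(('cuda',), {}))
# per the signature shown in the traceback comments, this should give:
# {'device_type': 'cuda', 'dtype': None, 'enabled': True, 'cache_enabled': None}

With something like this, torch.autocast('cuda') and torch.autocast(device_type='cuda') would normalize to the same set of target values.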
And the environment used for testing:
python -m torch.utils.collect_env
Collecting environment information...
PyTorch version: 2.0.0.dev20221215+cu116
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.4 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.25.0
Libc version: glibc-2.27
Python version: 3.10.8 (main, Nov 24 2022, 14:13:03) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-4.15.0-191-generic-x86_64-with-glibc2.27
Is CUDA available: True
CUDA runtime version: 9.1.85
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: Tesla V100-PCIE-32GB
GPU 1: Tesla V100-PCIE-32GB
Nvidia driver version: 470.141.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] lovely-numpy==0.2.2
[pip3] numpy==1.24.0rc2
[pip3] torch==2.0.0.dev20221215+cu116
[pip3] torchaudio==0.14.0.dev20221214+cu116
[pip3] torchtriton==2.0.0+0d7e753227
[pip3] torchvision==0.15.0.dev20221214+cu116
[conda] blas 1.0 mkl
[conda] lovely-numpy 0.2.2 pypi_0 pypi
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py310h7f8727e_0
[conda] mkl_fft 1.3.1 py310hd6ae3a3_0
[conda] mkl_random 1.2.2 py310h00e6091_0
[conda] numpy 1.24.0rc2 pypi_0 pypi
[conda] torch 2.0.0.dev20221215+cu116 pypi_0 pypi
[conda] torchaudio 0.14.0.dev20221214+cu116 pypi_0 pypi
[conda] torchtriton 2.0.0+0d7e753227 pypi_0 pypi
[conda] torchvision 0.15.0.dev20221214+cu116 pypi_0 pypi