Using torch.compile with autocast

I was trying out the new torch.compile function and ran into an error when compiling code that uses autocast. I’m not sure whether this behaviour is intended; if it is, it isn’t clear why.

The following snippet should reproduce the error.

import torch
    
def autocast():  # eager baseline: not wrapped in torch.compile
    with torch.autocast('cuda'):  # device type passed positionally
        return
    
autocast() # runs just fine

@torch.compile()
def opt_autocast():  # same body as autocast(), but traced by torch.compile
    with torch.autocast('cuda'):  # positional 'cuda': the traceback's "from user code" frame points here
        return
    
opt_autocast() # throws AssertionError

Here is the traceback, with torch._dynamo.config.verbose=True:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[23], line 14
     11     with torch.autocast('cuda'):
     12         return
---> 14 opt_autocast()

File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:211, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    209 dynamic_ctx.__enter__()
    210 try:
--> 211     return fn(*args, **kwargs)
    212 finally:
    213     set_eval_frame(prior)

File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:332, in catch_errors_wrapper.<locals>.catch_errors(frame, cache_size)
    329             return hijacked_callback(frame, cache_size, hooks)
    331 with compile_lock:
--> 332     return callback(frame, cache_size, hooks)

File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:480, in convert_frame.<locals>._convert_frame(frame, cache_size, hooks)
    478 counters["frames"]["total"] += 1
    479 try:
--> 480     result = inner_convert(frame, cache_size, hooks)
    481     counters["frames"]["ok"] += 1
    482     return result

File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:103, in wrap_convert_context.<locals>._fn(*args, **kwargs)
    101 torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
    102 try:
--> 103     return fn(*args, **kwargs)
    104 finally:
    105     torch._C._set_grad_enabled(prior_grad_mode)

File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/utils.py:90, in dynamo_timed.<locals>.time_wrapper(*args, **kwargs)
     88     compilation_metrics[key] = []
     89 t0 = time.time()
---> 90 r = func(*args, **kwargs)
     91 latency = time.time() - t0
     92 # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")

File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:339, in convert_frame_assert.<locals>._convert_frame_assert(frame, cache_size, hooks)
    336 global initial_grad_state
    337 initial_grad_state = torch.is_grad_enabled()
--> 339 return _compile(
    340     frame.f_code,
    341     frame.f_globals,
    342     frame.f_locals,
    343     frame.f_builtins,
    344     compiler_fn,
    345     one_graph,
    346     export,
    347     hooks,
    348     frame,
    349 )

File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:400, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, hooks, frame)
    398 for attempt in itertools.count():
    399     try:
--> 400         out_code = transform_code_object(code, transform)
    401         orig_code_map[out_code] = code
    402         break

File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py:341, in transform_code_object(code, transformations, safe)
    338 instructions = cleaned_instructions(code, safe)
    339 propagate_line_nums(instructions)
--> 341 transformations(instructions, code_options)
    343 fix_vars(instructions, code_options)
    345 dirty = True

File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:387, in _compile.<locals>.transform(instructions, code_options)
    374 nonlocal output
    375 tracer = InstructionTranslator(
    376     instructions,
    377     code,
   (...)
    385     mutated_closure_cell_contents,
    386 )
--> 387 tracer.run()
    388 output = tracer.output
    389 assert output is not None

File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1687, in InstructionTranslator.run(self)
   1685 def run(self):
   1686     _step_logger()(logging.INFO, f"torchdynamo start tracing {self.f_code.co_name}")
-> 1687     super().run()

File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:538, in InstructionTranslatorBase.run(self)
    533 try:
    534     self.output.push_tx(self)
    535     while (
    536         self.instruction_pointer is not None
    537         and not self.output.should_exit
--> 538         and self.step()
    539     ):
    540         pass
    541 except BackendCompilerFailed:

File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:501, in InstructionTranslatorBase.step(self)
    499     if not hasattr(self, inst.opname):
    500         unimplemented(f"missing: {inst.opname}")
--> 501     getattr(self, inst.opname)(inst)
    503     return inst.opname != "RETURN_VALUE"
    504 except BackendCompilerFailed:

File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:307, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
    305 reason = None
    306 try:
--> 307     return inner_fn(self, inst)
    308 except Unsupported as excp:
    309     if self.has_backedge():

File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:966, in InstructionTranslatorBase.CALL_FUNCTION(self, inst)
    964 args = self.popn(inst.argval)
    965 fn = self.pop()
--> 966 self.call_function(fn, args, {})

File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:435, in InstructionTranslatorBase.call_function(self, fn, args, kwargs)
    430 assert isinstance(kwargs, dict)
    431 assert all(
    432     isinstance(x, VariableTracker)
    433     for x in itertools.chain(args, kwargs.values())
    434 )
--> 435 self.push(fn.call_function(self, args, kwargs))

File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/variables/torch.py:304, in TorchVariable.call_function(self, tx, args, kwargs)
    297     return TensorWithTFOverrideVariable(
    298         unwrapped,
    299         tensor_with_tf_override.orig_tensor_variable_source,
    300         tensor_with_tf_override.subclass_torch_function__func,
    301         tensor_with_tf_override.subclass_type,
    302     )
    303 elif self.value is torch.amp.autocast_mode.autocast:
--> 304     return AutocastModeVariable.create(target_values=args, kwargs=kwargs)
    305 elif self.value in (
    306     torch.profiler.profile,
    307     torch.profiler.record_function,
    308     torch.autograd.profiler.profile,
    309     torch.autograd.profiler.record_function,
    310 ):
    311     log.warning("Profiler will be ignored")

File ~/miniconda3/envs/ds/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py:311, in AutocastModeVariable.create(target_values, kwargs)
    306 values = target_values
    307 # device_type : str,
    308 # dtype : Optional[_dtype] = None,
    309 # enabled : bool = True,
    310 # cache_enabled : Optional[bool] = None):cache_enabled
--> 311 assert "device_type" in kwargs
    312 values.append(kwargs["device_type"])
    313 del kwargs["device_type"]

AssertionError: 

from user code:
   File "/tmp/ipykernel_38728/2263624541.py", line 11, in opt_autocast
    with torch.autocast('cuda'):


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

And here is the environment used for testing:

python -m torch.utils.collect_env
Collecting environment information...
PyTorch version: 2.0.0.dev20221215+cu116
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.4 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.25.0
Libc version: glibc-2.27

Python version: 3.10.8 (main, Nov 24 2022, 14:13:03) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-4.15.0-191-generic-x86_64-with-glibc2.27
Is CUDA available: True
CUDA runtime version: 9.1.85
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: Tesla V100-PCIE-32GB
GPU 1: Tesla V100-PCIE-32GB

Nvidia driver version: 470.141.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] lovely-numpy==0.2.2
[pip3] numpy==1.24.0rc2
[pip3] torch==2.0.0.dev20221215+cu116
[pip3] torchaudio==0.14.0.dev20221214+cu116
[pip3] torchtriton==2.0.0+0d7e753227
[pip3] torchvision==0.15.0.dev20221214+cu116
[conda] blas                      1.0                         mkl
[conda] lovely-numpy              0.2.2                    pypi_0    pypi
[conda] mkl                       2021.4.0           h06a4308_640
[conda] mkl-service               2.4.0           py310h7f8727e_0
[conda] mkl_fft                   1.3.1           py310hd6ae3a3_0
[conda] mkl_random                1.2.2           py310h00e6091_0
[conda] numpy                     1.24.0rc2                pypi_0    pypi
[conda] torch                     2.0.0.dev20221215+cu116          pypi_0    pypi
[conda] torchaudio                0.14.0.dev20221214+cu116          pypi_0    pypi
[conda] torchtriton               2.0.0+0d7e753227          pypi_0    pypi
[conda] torchvision               0.15.0.dev20221214+cu116          pypi_0    pypi

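I guess torch.compile is unhappy about the positional argument and expects a keyword argument: the traceback ends in AutocastModeVariable.create at the line `assert "device_type" in kwargs`, which fails when 'cuda' is passed positionally.
Passing device_type as a keyword argument works for me: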

@torch.compile()
def opt_autocast():
    with torch.autocast(device_type='cuda'):  # keyword form puts "device_type" in kwargs, as the failing assert requires
        return

opt_autocast()  # reported to run without the AssertionError