PyTorch 2.5.1, torch.compile error

I'm facing a torch.compile error on PyTorch 2.5.1 while benchmarking a compiled FlexAttention module (forward + backward). Full traceback:

---------------------------------------------------------------------------
SubprocException                          Traceback (most recent call last)
File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_dynamo/output_graph.py:1446, in OutputGraph._call_user_compiler(self, gm)
   1445     compiler_fn = WrapperBackend(compiler_fn)
-> 1446 compiled_fn = compiler_fn(gm, self.example_inputs())
   1447 _step_logger()(logging.INFO, f"done compiler function {name}")

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py:129, in WrapBackendDebug.__call__(self, gm, example_inputs, **kwargs)
    128 else:
--> 129     compiled_gm = compiler_fn(gm, example_inputs)
    131 return compiled_gm

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/__init__.py:2234, in _TorchCompileInductorWrapper.__call__(self, model_, inputs_)
   2232 from torch._inductor.compile_fx import compile_fx
-> 2234 return compile_fx(model_, inputs_, config_patches=self.config)

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_inductor/compile_fx.py:1521, in compile_fx(model_, example_inputs_, inner_compile, config_patches, decompositions)
   1516 with V.set_fake_mode(fake_mode), torch._guards.tracing(
   1517     tracing_context
   1518 ), compiled_autograd.disable(), functorch_config.patch(
   1519     unlift_effect_tokens=True
   1520 ):
-> 1521     return aot_autograd(
   1522         fw_compiler=fw_compiler,
   1523         bw_compiler=bw_compiler,
   1524         inference_compiler=inference_compiler,
   1525         decompositions=decompositions,
   1526         partition_fn=partition_fn,
   1527         keep_inference_input_mutations=True,
   1528         cudagraphs=cudagraphs,
   1529     )(model_, example_inputs_)

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_dynamo/backends/common.py:72, in AotAutograd.__call__(self, gm, example_inputs, **kwargs)
     71 with enable_aot_logging(), patch_config:
---> 72     cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
     73     counters["aot_autograd"]["ok"] += 1

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py:1071, in aot_module_simplified(mod, args, fw_compiler, bw_compiler, partition_fn, decompositions, keep_inference_input_mutations, inference_compiler, cudagraphs)
   1070 else:
-> 1071     compiled_fn = dispatch_and_compile()
   1073 if isinstance(mod, torch._dynamo.utils.GmWrapper):
   1074     # This function is called by the flatten_graph_inputs wrapper, which boxes
   1075     # the inputs so that they can be freed before the end of this scope.
   1076     # For overhead reasons, this is not the default wrapper, see comment:
   1077     # https://github.com/pytorch/pytorch/pull/122535/files#r1560096481

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py:1056, in aot_module_simplified.<locals>.dispatch_and_compile()
   1055 with compiled_autograd.disable():
-> 1056     compiled_fn, _ = create_aot_dispatcher_function(
   1057         functional_call,
   1058         fake_flat_args,
   1059         aot_config,
   1060         fake_mode,
   1061         shape_env,
   1062     )
   1063 return compiled_fn

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py:522, in create_aot_dispatcher_function(flat_fn, fake_flat_args, aot_config, fake_mode, shape_env)
    521 with dynamo_timed("create_aot_dispatcher_function"):
--> 522     return _create_aot_dispatcher_function(
    523         flat_fn, fake_flat_args, aot_config, fake_mode, shape_env
    524     )

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py:759, in _create_aot_dispatcher_function(flat_fn, fake_flat_args, aot_config, fake_mode, shape_env)
    757 compiler_fn = choose_dispatcher(needs_autograd, aot_config)
--> 759 compiled_fn, fw_metadata = compiler_fn(
    760     flat_fn,
    761     _dup_fake_script_obj(fake_flat_args),
    762     aot_config,
    763     fw_metadata=fw_metadata,
    764 )
    765 return compiled_fn, fw_metadata

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:588, in aot_dispatch_autograd(flat_fn, flat_args, aot_config, fw_metadata)
    587 with TracingContext.report_output_strides() as fwd_output_strides:
--> 588     compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
    590 if not hasattr(compiled_fw_func, "_boxed_call"):

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_inductor/compile_fx.py:1350, in compile_fx.<locals>.fw_compiler_base(model, example_inputs, is_inference)
   1349 with dynamo_utils.dynamo_timed("compile_fx.<locals>.fw_compiler_base"):
-> 1350     return _fw_compiler_base(model, example_inputs, is_inference)

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_inductor/compile_fx.py:1421, in compile_fx.<locals>._fw_compiler_base(model, example_inputs, is_inference)
   1413     user_visible_outputs = dict.fromkeys(
   1414         n.name
   1415         for n in model_outputs[
   (...)
   1418         if isinstance(n, torch.fx.Node)
   1419     )
-> 1421 return inner_compile(
   1422     model,
   1423     example_inputs,
   1424     static_input_idxs=get_static_input_idxs(fixed),
   1425     cudagraphs=cudagraphs,
   1426     graph_id=graph_id,
   1427     is_inference=is_inference,
   1428     boxed_forward_device_index=forward_device,
   1429     user_visible_outputs=user_visible_outputs,
   1430 )

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_inductor/compile_fx.py:475, in compile_fx_inner(*args, **kwargs)
    473 stack.enter_context(DebugContext())
--> 475 return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
    476     *args, **kwargs
    477 )

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_dynamo/repro/after_aot.py:85, in wrap_compiler_debug.<locals>.debug_wrapper(gm, example_inputs, **kwargs)
     82 try:
     83     # Call the compiler_fn - which is either aot_autograd or inductor
     84     # with fake inputs
---> 85     inner_compiled_fn = compiler_fn(gm, example_inputs)
     86 except Exception as e:
     87     # TODO: Failures here are troublesome because no real inputs,
     88     # need a different serialization strategy

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_inductor/compile_fx.py:661, in _compile_fx_inner(gm, example_inputs, cudagraphs, static_input_idxs, is_backward, graph_id, cpp_wrapper, aot_mode, is_inference, boxed_forward_device_index, user_visible_outputs, layout_opt, extern_node_serializer)
    659             input._is_inductor_static = True  # type: ignore[attr-defined]
--> 661     compiled_graph = FxGraphCache.load(
    662         codegen_and_compile,
    663         gm,
    664         example_inputs,
    665         graph_kwargs,
    666         inputs_to_check,
    667         local=config.fx_graph_cache,
    668         remote=fx_graph_remote_cache,
    669     )
    670 else:

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_inductor/codecache.py:1370, in FxGraphCache.load(compile_fx_fn, gm, example_inputs, fx_kwargs, inputs_to_check, local, remote)
   1369 if not compiled_graph:
-> 1370     compiled_graph = compile_fx_fn(
   1371         gm, example_inputs, inputs_to_check, fx_kwargs
   1372     )
   1373 assert compiled_graph is not None

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_inductor/compile_fx.py:570, in _compile_fx_inner.<locals>.codegen_and_compile(gm, example_inputs, inputs_to_check, fx_kwargs)
    566 """
    567 This function calls fx_codegen_and_compile and also adds some extra metadata to the resulting
    568 compiled fx graph. The metadata is saved to FXGraphCache.
    569 """
--> 570 compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
    571 if isinstance(compiled_graph, str):
    572     # We only return a string in aot mode, in which case we don't
    573     # need to do any post-compilation steps: we just return the string,
    574     # which is the filename of the compiled code.

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_inductor/compile_fx.py:878, in fx_codegen_and_compile(gm, example_inputs, cudagraphs, static_input_idxs, is_backward, graph_id, cpp_wrapper, aot_mode, is_inference, user_visible_outputs, layout_opt, extern_node_serializer)
    877 _check_triton_bf16_support(graph)
--> 878 compiled_fn = graph.compile_to_fn()
    879 num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_inductor/graph.py:1913, in GraphLowering.compile_to_fn(self)
   1912 else:
-> 1913     return self.compile_to_module().call

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_inductor/graph.py:1839, in GraphLowering.compile_to_module(self)
   1836 with dynamo_timed(
   1837     "GraphLowering.compile_to_module", phase_name="code_gen", fwd_only=False
   1838 ):
-> 1839     return self._compile_to_module()

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_inductor/graph.py:1867, in GraphLowering._compile_to_module(self)
   1861     trace_structured(
   1862         "inductor_output_code",
   1863         lambda: {"filename": path},
   1864         payload_fn=lambda: code,
   1865     )
-> 1867 mod = PyCodeCache.load_by_key_path(
   1868     key,
   1869     path,
   1870     linemap=linemap,  # type: ignore[arg-type]
   1871     attrs={**self.constants, **self.torchbind_constants},
   1872 )
   1873 self.cache_key = key

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_inductor/codecache.py:2876, in PyCodeCache.load_by_key_path(cls, key, path, linemap, attrs)
   2875 if key not in cls.cache:
-> 2876     mod = _reload_python_module(key, path)
   2878     # another thread might set this first

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_inductor/runtime/compile_tasks.py:45, in _reload_python_module(key, path)
     44 mod.key = key  # type: ignore[attr-defined]
---> 45 exec(code, mod.__dict__, mod.__dict__)
     46 sys.modules[mod.__name__] = mod

File /tmp/torchinductor_cataluna84/xa/cxafrl5saryx7ancqcvnpenbt34cmndw4nbwhfaniwycstnlu7tb.py:623
    620 meta0 = {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': True, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.125, 'GQA_SHARED_HEADS': 1, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 64, 'V_HEAD_DIM': 64, 'BLOCK_M': 32, 'BLOCK_N': 16, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}
--> 623 async_compile.wait(globals())
    624 del async_compile

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_inductor/async_compile.py:276, in AsyncCompile.wait(self, scope)
    275 try:
--> 276     scope[key] = result.result()
    277 except BrokenProcessPool as e:

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_inductor/codecache.py:3341, in TritonFuture.result(self)
   3339 if self.future is not None:
   3340     # If the worker failed this will throw an exception.
-> 3341     result = self.future.result()
   3342     assert result is None

File ~/anaconda3/envs/book-codebase/lib/python3.12/concurrent/futures/_base.py:456, in Future.result(self, timeout)
    455 elif self._state == FINISHED:
--> 456     return self.__get_result()
    457 else:

File ~/anaconda3/envs/book-codebase/lib/python3.12/concurrent/futures/_base.py:401, in Future.__get_result(self)
    400 try:
--> 401     raise self._exception
    402 finally:
    403     # Break a reference cycle with the exception in self._exception

SubprocException: An exception occurred in a subprocess:

Traceback (most recent call last):
  File "/home/cataluna84/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_inductor/compile_worker/subproc_pool.py", line 270, in do_job
    result = job()
             ^^^^^
  File "/home/cataluna84/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_inductor/runtime/compile_tasks.py", line 68, in _worker_compile_triton
    load_kernel().precompile(warm_cache_only=True)
  File "/home/cataluna84/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 244, in precompile
    compiled_binary, launcher = self._precompile_config(
                                ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cataluna84/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 428, in _precompile_config
    triton.compile(*compile_args, **compile_kwargs),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cataluna84/anaconda3/envs/book-codebase/lib/python3.12/site-packages/triton/compiler/compiler.py", line 282, in compile
    next_module = compile_ir(module, metadata)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cataluna84/anaconda3/envs/book-codebase/lib/python3.12/site-packages/triton/backends/nvidia/compiler.py", line 318, in <lambda>
    stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cataluna84/anaconda3/envs/book-codebase/lib/python3.12/site-packages/triton/backends/nvidia/compiler.py", line 216, in make_llir
    pm.run(mod)
IndexError: map::at


The above exception was the direct cause of the following exception:

BackendCompilerFailed                     Traceback (most recent call last)
Cell In[47], line 1
----> 1 execution_stats = [time_pytorch_function_forward_backward(prepare_function(fn), embeddings) for fn in functions.values()]
      2 execution_means = [stat[0] for stat in execution_stats]
      3 execution_stds = [stat[1] for stat in execution_stats]

Cell In[44], line 17, in time_pytorch_function_forward_backward(func, num_repeats, *input)
     15 # Warmup
     16 for _ in range(5):
---> 17     forward_backward(func, *input)
     18 torch.cuda.synchronize()
     20 times = []

Cell In[44], line 5, in forward_backward(func, embeddings)
      2 if embeddings.grad is not None:
      3     embeddings.grad.zero_()
----> 5 output = func(embeddings)
      6 loss = output.sum()
      7 loss.backward()

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:465, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    460 saved_dynamic_layer_stack_depth = (
    461     torch._C._functorch.get_dynamic_layer_stack_depth()
    462 )
    464 try:
--> 465     return fn(*args, **kwargs)
    466 finally:
    467     # Restore the dynamic layer stack depth if necessary.
    468     torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(
    469         saved_dynamic_layer_stack_depth
    470     )

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:1269, in CatchErrorsWrapper.__call__(self, frame, cache_entry, frame_state)
   1263             return hijacked_callback(
   1264                 frame, cache_entry, self.hooks, frame_state
   1265             )
   1267 with compile_lock, _disable_current_modes():
   1268     # skip=1: skip this frame
-> 1269     return self._torchdynamo_orig_callable(
   1270         frame, cache_entry, self.hooks, frame_state, skip=1
   1271     )

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:1064, in ConvertFrame.__call__(self, frame, cache_entry, hooks, frame_state, skip)
   1062 counters["frames"]["total"] += 1
   1063 try:
-> 1064     result = self._inner_convert(
   1065         frame, cache_entry, hooks, frame_state, skip=skip + 1
   1066     )
   1067     counters["frames"]["ok"] += 1
   1068     return result

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:526, in ConvertFrameAssert.__call__(self, frame, cache_entry, hooks, frame_state, skip)
    510 compile_id = CompileId(frame_id, frame_compile_id)
    512 signpost_event(
    513     "dynamo",
    514     "_convert_frame_assert._compile",
   (...)
    523     },
    524 )
--> 526 return _compile(
    527     frame.f_code,
    528     frame.f_globals,
    529     frame.f_locals,
    530     frame.f_builtins,
    531     self._torchdynamo_orig_callable,
    532     self._one_graph,
    533     self._export,
    534     self._export_constraints,
    535     hooks,
    536     cache_entry,
    537     cache_size,
    538     frame,
    539     frame_state=frame_state,
    540     compile_id=compile_id,
    541     skip=skip + 1,
    542 )

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:924, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, export_constraints, hooks, cache_entry, cache_size, frame, frame_state, compile_id, skip)
    922 guarded_code = None
    923 try:
--> 924     guarded_code = compile_inner(code, one_graph, hooks, transform)
    925     return guarded_code
    926 except Exception as e:

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:666, in _compile.<locals>.compile_inner(code, one_graph, hooks, transform)
    664 with dynamo_timed("_compile.compile_inner", phase_name="entire_frame_compile"):
    665     with CompileTimeInstructionCounter.record():
--> 666         return _compile_inner(code, one_graph, hooks, transform)

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_utils_internal.py:87, in compile_time_strobelight_meta.<locals>.compile_time_strobelight_meta_inner.<locals>.wrapper_function(*args, **kwargs)
     84     kwargs["skip"] = kwargs["skip"] + 1
     86 if not StrobelightCompileTimeProfiler.enabled:
---> 87     return function(*args, **kwargs)
     89 return StrobelightCompileTimeProfiler.profile_compile_time(
     90     function, phase_name, *args, **kwargs
     91 )

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:699, in _compile.<locals>._compile_inner(code, one_graph, hooks, transform)
    697 CompileContext.get().attempt = attempt
    698 try:
--> 699     out_code = transform_code_object(code, transform)
    700     break
    701 except exc.RestartAnalysis as e:

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py:1322, in transform_code_object(code, transformations, safe)
   1319 instructions = cleaned_instructions(code, safe)
   1320 propagate_line_nums(instructions)
-> 1322 transformations(instructions, code_options)
   1323 return clean_and_assemble_instructions(instructions, keys, code_options)[1]

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:219, in preserve_global_state.<locals>._fn(*args, **kwargs)
    215 exit_stack.enter_context(
    216     torch.fx._symbolic_trace._maybe_revert_all_patches()
    217 )
    218 try:
--> 219     return fn(*args, **kwargs)
    220 finally:
    221     cleanup.close()

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:634, in _compile.<locals>.transform(instructions, code_options)
    632 try:
    633     with tracing(tracer.output.tracing_context), tracer.set_current_tx():
--> 634         tracer.run()
    635 except exc.UnspecializeRestartAnalysis:
    636     speculation_log.clear()

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2796, in InstructionTranslator.run(self)
   2795 def run(self):
-> 2796     super().run()

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:983, in InstructionTranslatorBase.run(self)
    981 try:
    982     self.output.push_tx(self)
--> 983     while self.step():
    984         pass
    985 except BackendCompilerFailed:

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:895, in InstructionTranslatorBase.step(self)
    892 self.update_block_stack(inst)
    894 try:
--> 895     self.dispatch_table[inst.opcode](self, inst)
    896     return not self.output.should_exit
    897 except exc.ObservedException as e:

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2987, in InstructionTranslator.RETURN_VALUE(self, inst)
   2986 def RETURN_VALUE(self, inst):
-> 2987     self._return(inst)

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2972, in InstructionTranslator._return(self, inst)
   2967 _step_logger()(
   2968     logging.INFO,
   2969     f"torchdynamo done tracing {self.f_code.co_name} ({inst.opname})",
   2970 )
   2971 log.debug("%s triggered compile", inst.opname)
-> 2972 self.output.compile_subgraph(
   2973     self,
   2974     reason=GraphCompileReason(
   2975         "return_value", [self.frame_summary()], graph_break=False
   2976     ),
   2977 )
   2978 return_inst = (
   2979     create_instruction("RETURN_VALUE")
   2980     if inst.opname == "RETURN_VALUE"
   2981     else create_instruction("RETURN_CONST", argval=inst.argval)
   2982 )
   2983 self.output.add_output_instructions([return_inst])

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_dynamo/output_graph.py:1117, in OutputGraph.compile_subgraph(self, tx, partial_convert, reason)
   1114 append_prefix_insts()
   1115 # optimization to generate better code in a common case
   1116 self.add_output_instructions(
-> 1117     self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
   1118     + [create_instruction("UNPACK_SEQUENCE", arg=len(stack_values))]
   1119 )
   1120 # restore all the live local vars
   1121 self.add_output_instructions(
   1122     [PyCodegen(tx).create_store(var) for var in reversed(restore_vars)]
   1123 )

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_dynamo/output_graph.py:1369, in OutputGraph.compile_and_call_fx_graph(self, tx, rv, root)
   1366     self.tracing_context.fake_mode = backend_fake_mode
   1368 with self.restore_global_state():
-> 1369     compiled_fn = self.call_user_compiler(gm)
   1371 from torch.fx._lazy_graph_module import _LazyGraphModule
   1373 if isinstance(compiled_fn, _LazyGraphModule) or (
   1374     isinstance(getattr(compiled_fn, "__self__", None), _LazyGraphModule)
   1375     and compiled_fn.__name__ == "_lazy_forward"  # type: ignore[attr-defined]
   (...)
   1379     # this is a _LazyGraphModule. This makes it easier for dynamo to
   1380     # optimize a _LazyGraphModule.

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_dynamo/output_graph.py:1416, in OutputGraph.call_user_compiler(self, gm)
   1412 def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
   1413     with dynamo_timed(
   1414         "OutputGraph.call_user_compiler", phase_name="backend_compile"
   1415     ):
-> 1416         return self._call_user_compiler(gm)

File ~/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_dynamo/output_graph.py:1465, in OutputGraph._call_user_compiler(self, gm)
   1463     raise e
   1464 except Exception as e:
-> 1465     raise BackendCompilerFailed(self.compiler_fn, e) from e
   1467 signpost_event(
   1468     "dynamo",
   1469     "OutputGraph.call_user_compiler",
   (...)
   1475     },
   1476 )
   1478 return compiled_fn

BackendCompilerFailed: backend='inductor' raised:
SubprocException: An exception occurred in a subprocess:

Traceback (most recent call last):
  File "/home/cataluna84/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_inductor/compile_worker/subproc_pool.py", line 270, in do_job
    result = job()
             ^^^^^
  File "/home/cataluna84/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_inductor/runtime/compile_tasks.py", line 68, in _worker_compile_triton
    load_kernel().precompile(warm_cache_only=True)
  File "/home/cataluna84/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 244, in precompile
    compiled_binary, launcher = self._precompile_config(
                                ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cataluna84/anaconda3/envs/book-codebase/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 428, in _precompile_config
    triton.compile(*compile_args, **compile_kwargs),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cataluna84/anaconda3/envs/book-codebase/lib/python3.12/site-packages/triton/compiler/compiler.py", line 282, in compile
    next_module = compile_ir(module, metadata)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cataluna84/anaconda3/envs/book-codebase/lib/python3.12/site-packages/triton/backends/nvidia/compiler.py", line 318, in <lambda>
    stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cataluna84/anaconda3/envs/book-codebase/lib/python3.12/site-packages/triton/backends/nvidia/compiler.py", line 216, in make_llir
    pm.run(mod)
IndexError: map::at


Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

I don't want to fall back to eager. Any solutions?
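In the meantime, this is how I collect the extra logs the message points to (a sketch; as far as I know, torch._logging.set_logs is the programmatic equivalent of the TORCH_LOGS / TORCHDYNAMO_VERBOSE environment variables, and it has to run before compilation is triggered):

    import logging
    import torch._dynamo
    import torch._logging

    # Equivalent of TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1
    torch._logging.set_logs(dynamo=logging.DEBUG)
    torch._dynamo.config.verbose = True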

OS: Ubuntu 22.04.5 LTS
GPU: Nvidia RTX 2070 (8GB VRAM)
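Since the failing Triton stage compiles against the device's compute capability (make_llir receives self.capability in the traceback), this is how I confirm what my card reports (the values in the comments are what I expect for a Turing GPU):

    import torch

    print(torch.cuda.get_device_name())        # 'NVIDIA GeForce RTX 2070'
    print(torch.cuda.get_device_capability())  # (7, 5) for Turing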

Could you provide a sample model that would replicate the error?

I hit the torch._dynamo error shared above, so in the FlexAttention part I see smaller gains from the torch.compile version than the gains the author reports on an A100.
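The code comes from the book's notebook; below is a minimal sketch of the pattern that fails for me. The module and dimension names are hypothetical stand-ins, not the exact book code, but the shapes follow the kernel metadata in the traceback (QK_HEAD_DIM = V_HEAD_DIM = 64, hence SM_SCALE = 0.125):

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    class FlexAttentionWrapper(torch.nn.Module):
        # Hypothetical stand-in for the notebook's FlexAttention variant
        def __init__(self, d_model=512, num_heads=8):
            super().__init__()
            self.num_heads = num_heads
            self.head_dim = d_model // num_heads  # 64, matching QK_HEAD_DIM in the log
            self.qkv = torch.nn.Linear(d_model, 3 * d_model)

        def forward(self, x):
            b, s, _ = x.shape
            q, k, v = self.qkv(x).chunk(3, dim=-1)
            # (batch, seq, d_model) -> (batch, heads, seq, head_dim)
            q = q.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
            k = k.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
            v = v.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
            return flex_attention(q, k, v)

    model = torch.compile(FlexAttentionWrapper().to("cuda"))
    embeddings = torch.randn(2, 128, 512, device="cuda", requires_grad=True)

    output = model(embeddings)  # Triton kernel compilation is triggered here
    output.sum().backward()     # forward + backward, as in the timing helper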