torch.compile error when using FlexAttention + block mask

Hi,

I am trying to compile a model that uses FlexAttention together with an attention mask. To show the problem, I put together a minimal example that reproduces the error I am getting.

import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def _bidirectional_block_mask(attention_mask):
    # Returns a mask_mod with the (b, h, q_idx, kv_idx) signature expected by
    # create_block_mask; it looks up the padding mask per batch/kv position.
    def bidirectional_fn(b, h, q_idx, kv_idx):
        kv_idx = torch.clamp(kv_idx, 0, attention_mask.size(1) - 1)
        b = torch.clamp(b, 0, attention_mask.size(0) - 1)
        return attention_mask[b, kv_idx]

    return bidirectional_fn


class MyModel(nn.Module):
    def __init__(self, input_dim, num_heads):
        super(MyModel, self).__init__()
        self.input_dim = input_dim
        self.num_heads = num_heads
        self.query_projection = nn.Linear(input_dim, input_dim)
        self.key_projection = nn.Linear(input_dim, input_dim)
        self.value_projection = nn.Linear(input_dim, input_dim)
        self.output_projection = nn.Linear(input_dim, input_dim)
        
    def forward(self, input):
        batch_size, query_seq_len, _ = input.size()
        key_seq_len = query_seq_len

        attention_mask = torch.ones((batch_size, key_seq_len), dtype=torch.bool, device=input.device)

        bidirectional_block = create_block_mask(
            _bidirectional_block_mask(attention_mask),
            B=batch_size,
            H=self.num_heads,
            Q_LEN=query_seq_len,
            KV_LEN=key_seq_len,
            device=input.device,
        )

        # Project inputs to query, key, and value
        queries = self.query_projection(input)
        keys = self.key_projection(input)
        values = self.value_projection(input)
        
        # Reshape queries, keys, and values for multi-head attention
        queries = queries.view(batch_size, query_seq_len, self.num_heads, -1).transpose(1, 2)
        keys = keys.view(batch_size, key_seq_len, self.num_heads, -1).transpose(1, 2)
        values = values.view(batch_size, key_seq_len, self.num_heads, -1).transpose(1, 2)

        attention_output = flex_attention(queries, keys, values, block_mask=bidirectional_block)
        # Reshape attention output
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, query_seq_len, -1)
        
        # Project attention output
        attention_output = self.output_projection(attention_output)
        
        return attention_output
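
For completeness, this is how I build and run the model (the same calls shown in the traceback below). The input shape itself is just an example I picked; any (batch, seq_len, input_dim) tensor hits the error:

input = torch.randn(2, 128, 32, device="cuda")  # (batch, seq_len, input_dim); example shape
model = MyModel(32, 1).cuda()
model = torch.compile(model)
output = model(input)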

The error I am getting is the following:

---------------------------------------------------------------------------
InternalTorchDynamoError                  Traceback (most recent call last)
<ipython-input-5-454293cadcc3> in <cell line: 4>()
      2 model = MyModel(32, 1).cuda()
      3 model = torch.compile(model)
----> 4 output = model(input)
      5 output

45 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749         result = None

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py in _fn(*args, **kwargs)
    463 
    464             try:
--> 465                 return fn(*args, **kwargs)
    466             finally:
    467                 # Restore the dynamic layer stack depth if necessary.

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749         result = None

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in __call__(self, frame, cache_entry, frame_state)
   1267         with compile_lock, _disable_current_modes():
   1268             # skip=1: skip this frame
-> 1269             return self._torchdynamo_orig_callable(
   1270                 frame, cache_entry, self.hooks, frame_state, skip=1
   1271             )

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in __call__(self, frame, cache_entry, hooks, frame_state, skip)
   1062         counters["frames"]["total"] += 1
   1063         try:
-> 1064             result = self._inner_convert(
   1065                 frame, cache_entry, hooks, frame_state, skip=skip + 1
   1066             )

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in __call__(self, frame, cache_entry, hooks, frame_state, skip)
    524         )
    525 
--> 526         return _compile(
    527             frame.f_code,
    528             frame.f_globals,

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, export_constraints, hooks, cache_entry, cache_size, frame, frame_state, compile_id, skip)
    950             else:
    951                 # Rewrap for clarity
--> 952                 raise InternalTorchDynamoError(
    953                     f"{type(e).__qualname__}: {str(e)}"
    954                 ).with_traceback(e.__traceback__) from None

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, export_constraints, hooks, cache_entry, cache_size, frame, frame_state, compile_id, skip)
    922         guarded_code = None
    923         try:
--> 924             guarded_code = compile_inner(code, one_graph, hooks, transform)
    925             return guarded_code
    926         except Exception as e:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in compile_inner(code, one_graph, hooks, transform)
    664         with dynamo_timed("_compile.compile_inner", phase_name="entire_frame_compile"):
    665             with CompileTimeInstructionCounter.record():
--> 666                 return _compile_inner(code, one_graph, hooks, transform)
    667 
    668     @compile_time_strobelight_meta(phase_name="compile_inner")

/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py in wrapper_function(*args, **kwargs)
     85 
     86             if not StrobelightCompileTimeProfiler.enabled:
---> 87                 return function(*args, **kwargs)
     88 
     89             return StrobelightCompileTimeProfiler.profile_compile_time(

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in _compile_inner(code, one_graph, hooks, transform)
    697             CompileContext.get().attempt = attempt
    698             try:
--> 699                 out_code = transform_code_object(code, transform)
    700                 break
    701             except exc.RestartAnalysis as e:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py in transform_code_object(code, transformations, safe)
   1320     propagate_line_nums(instructions)
   1321 
-> 1322     transformations(instructions, code_options)
   1323     return clean_and_assemble_instructions(instructions, keys, code_options)[1]
   1324 

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in _fn(*args, **kwargs)
    217             )
    218             try:
--> 219                 return fn(*args, **kwargs)
    220             finally:
    221                 cleanup.close()

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in transform(instructions, code_options)
    632         try:
    633             with tracing(tracer.output.tracing_context), tracer.set_current_tx():
--> 634                 tracer.run()
    635         except exc.UnspecializeRestartAnalysis:
    636             speculation_log.clear()

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in run(self)
   2794 
   2795     def run(self):
-> 2796         super().run()
   2797 
   2798     def match_nested_cell(self, name, cell):

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in run(self)
    981             try:
    982                 self.output.push_tx(self)
--> 983                 while self.step():
    984                     pass
    985             except BackendCompilerFailed:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in step(self)
    893 
    894         try:
--> 895             self.dispatch_table[inst.opcode](self, inst)
    896             return not self.output.should_exit
    897         except exc.ObservedException as e:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in wrapper(self, inst)
    580                 return handle_graph_break(self, inst, speculation.reason)
    581             try:
--> 582                 return inner_fn(self, inst)
    583             except Unsupported as excp:
    584                 if self.generic_context_manager_depth > 0:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in CALL_FUNCTION_KW(self, inst)
   1690         kwargs = dict(zip(argnames, kwargs_list))
   1691         assert len(kwargs) == len(argnames)
-> 1692         self.call_function(fn, args, kwargs)
   1693 
   1694     def LOAD_METHOD_SUPER(self, inst):

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in call_function(self, fn, args, kwargs)
    828         if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
    829             raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
--> 830         self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
    831 
    832     def inline_user_function_return(self, fn, args, kwargs):

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py in call_function(self, tx, args, kwargs)
    322             )
    323 
--> 324         return super().call_function(tx, args, kwargs)
    325 
    326 

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py in call_function(self, tx, args, kwargs)
    109         kwargs: "Dict[str, VariableTracker]",
    110     ) -> "VariableTracker":
--> 111         return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
    112 
    113     def call_hasattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in inline_user_function_return(self, fn, args, kwargs)
    834         A call to some user defined function by inlining it.
    835         """
--> 836         return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
    837 
    838     def get_line_of_code_header(self, lineno=None):

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in inline_call(cls, parent, func, args, kwargs)
   3009     def inline_call(cls, parent, func, args, kwargs):
   3010         with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
-> 3011             return cls.inline_call_(parent, func, args, kwargs)
   3012 
   3013     @staticmethod

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in inline_call_(parent, func, args, kwargs)
   3137         try:
   3138             with strict_ctx:
-> 3139                 tracer.run()
   3140         except exc.ObservedException as e:
   3141             msg = f"Observed exception DURING INLING {code} : {e}"

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in run(self)
    981             try:
    982                 self.output.push_tx(self)
--> 983                 while self.step():
    984                     pass
    985             except BackendCompilerFailed:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in step(self)
    893 
    894         try:
--> 895             self.dispatch_table[inst.opcode](self, inst)
    896             return not self.output.should_exit
    897         except exc.ObservedException as e:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in wrapper(self, inst)
    580                 return handle_graph_break(self, inst, speculation.reason)
    581             try:
--> 582                 return inner_fn(self, inst)
    583             except Unsupported as excp:
    584                 if self.generic_context_manager_depth > 0:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in CALL_FUNCTION(self, inst)
   1600         args = self.popn(inst.argval)
   1601         fn = self.pop()
-> 1602         self.call_function(fn, args, {})
   1603 
   1604     @break_graph_if_unsupported(push=1)

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in call_function(self, fn, args, kwargs)
    828         if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
    829             raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
--> 830         self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
    831 
    832     def inline_user_function_return(self, fn, args, kwargs):

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py in call_function(self, tx, args, kwargs)
    322             )
    323 
--> 324         return super().call_function(tx, args, kwargs)
    325 
    326 

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py in call_function(self, tx, args, kwargs)
    109         kwargs: "Dict[str, VariableTracker]",
    110     ) -> "VariableTracker":
--> 111         return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
    112 
    113     def call_hasattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in inline_user_function_return(self, fn, args, kwargs)
    834         A call to some user defined function by inlining it.
    835         """
--> 836         return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
    837 
    838     def get_line_of_code_header(self, lineno=None):

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in inline_call(cls, parent, func, args, kwargs)
   3009     def inline_call(cls, parent, func, args, kwargs):
   3010         with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
-> 3011             return cls.inline_call_(parent, func, args, kwargs)
   3012 
   3013     @staticmethod

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in inline_call_(parent, func, args, kwargs)
   3137         try:
   3138             with strict_ctx:
-> 3139                 tracer.run()
   3140         except exc.ObservedException as e:
   3141             msg = f"Observed exception DURING INLING {code} : {e}"

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in run(self)
    981             try:
    982                 self.output.push_tx(self)
--> 983                 while self.step():
    984                     pass
    985             except BackendCompilerFailed:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in step(self)
    893 
    894         try:
--> 895             self.dispatch_table[inst.opcode](self, inst)
    896             return not self.output.should_exit
    897         except exc.ObservedException as e:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in wrapper(self, inst)
    580                 return handle_graph_break(self, inst, speculation.reason)
    581             try:
--> 582                 return inner_fn(self, inst)
    583             except Unsupported as excp:
    584                 if self.generic_context_manager_depth > 0:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in CALL_FUNCTION(self, inst)
   1600         args = self.popn(inst.argval)
   1601         fn = self.pop()
-> 1602         self.call_function(fn, args, {})
   1603 
   1604     @break_graph_if_unsupported(push=1)

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in call_function(self, fn, args, kwargs)
    828         if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
    829             raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
--> 830         self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
    831 
    832     def inline_user_function_return(self, fn, args, kwargs):

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/misc.py in call_function(self, tx, args, kwargs)
    975         kwargs: "Dict[str, VariableTracker]",
    976     ) -> "VariableTracker":
--> 977         return self.fn(*args, **kwargs)
    978 
    979 

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/misc.py in create(callable, **kwargs)
    386         if kwargs:
    387             unimplemented(f"inspect.signature with {kwargs}")
--> 388         return InspectSignatureVariable(
    389             callable, mutable_local=variables.base.MutableLocal()
    390         )

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/misc.py in __init__(self, inspected, **kwargs)
    403             self.parameters = list(self.signature.parameters.items())
    404         else:
--> 405             self.fn = self.inspected.as_python_constant()
    406             self.signature = inspect.signature(self.fn)
    407             self.parameters = list(self.signature.parameters.items())

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/base.py in as_python_constant(self)
    215     def as_python_constant(self):
    216         """For constants"""
--> 217         raise NotImplementedError(f"{self} is not a constant")
    218 
    219     def guard_as_python_constant(self):

InternalTorchDynamoError: NotImplementedError: NestedUserFunctionVariable() is not a constant

from user code:
   File "<ipython-input-4-42bb243baef1>", line 26, in forward
    bidirectional_block = create_block_mask(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/attention/flex_attention.py", line 827, in create_block_mask
    mod_type = _get_mod_type(mask_mod)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/attention/flex_attention.py", line 63, in _get_mod_type
    for param in inspect.signature(fn).parameters.values()

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Am I doing something wrong? Is using a block mask the recommended way to perform attention masking with flex_attention?
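
For reference, my reading of the traceback is that create_block_mask calls inspect.signature on my nested mask function while Dynamo is tracing forward, and Dynamo cannot treat the closure as a constant. So one variant I am experimenting with is building the block mask eagerly outside the compiled region and passing it in. A rough sketch of what I mean, under that assumption (shapes and names here are my own, and I haven't fully verified this is the intended pattern):

# Build the BlockMask before entering compiled code, so Dynamo never has to
# trace create_block_mask / inspect.signature over the nested closure.
attention_mask = torch.ones((2, 128), dtype=torch.bool, device="cuda")
block_mask = create_block_mask(
    _bidirectional_block_mask(attention_mask),
    B=2, H=1, Q_LEN=128, KV_LEN=128, device="cuda",
)

# Compile only the attention call and pass the precomputed mask in:
compiled_attention = torch.compile(flex_attention)
# attention_output = compiled_attention(queries, keys, values, block_mask=block_mask)

Even if that sidesteps the error, I would still like to know whether creating the block mask inside a compiled forward is supposed to work.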