Hi,
I am trying to compile a model that uses flex attention with an attention mask. To show what happens, I put together a small script that reproduces the error:
import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def _bidirectional_block_mask(attention_mask):
    # mask_mod: keep a (q, kv) pair whenever the key position is valid in the
    # padding mask; clamp the indices so block-padded positions don't go out of range
    def bidirectional_fn(b, h, q_idx, kv_idx):
        kv_idx = torch.clamp(kv_idx, 0, attention_mask.size(1) - 1)
        b = torch.clamp(b, 0, attention_mask.size(0) - 1)
        return attention_mask[b, kv_idx]

    return bidirectional_fn


class MyModel(nn.Module):
    def __init__(self, input_dim, num_heads):
        super().__init__()
        self.input_dim = input_dim
        self.num_heads = num_heads
        self.query_projection = nn.Linear(input_dim, input_dim)
        self.key_projection = nn.Linear(input_dim, input_dim)
        self.value_projection = nn.Linear(input_dim, input_dim)
        self.output_projection = nn.Linear(input_dim, input_dim)

    def forward(self, input):
        batch_size, query_seq_len, _ = input.size()
        key_seq_len = query_seq_len
        attention_mask = torch.ones((batch_size, key_seq_len), dtype=torch.bool, device=input.device)

        bidirectional_block = create_block_mask(
            _bidirectional_block_mask(attention_mask),
            B=batch_size,
            H=self.num_heads,
            Q_LEN=query_seq_len,
            KV_LEN=query_seq_len,
            device=input.device,
        )

        # Project inputs to query, key, and value
        queries = self.query_projection(input)
        keys = self.key_projection(input)
        values = self.value_projection(input)

        # Reshape queries, keys, and values for multi-head attention
        queries = queries.view(batch_size, query_seq_len, self.num_heads, -1).transpose(1, 2)
        keys = keys.view(batch_size, key_seq_len, self.num_heads, -1).transpose(1, 2)
        values = values.view(batch_size, key_seq_len, self.num_heads, -1).transpose(1, 2)

        attention_output = flex_attention(queries, keys, values, block_mask=bidirectional_block)

        # Reshape attention output
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, query_seq_len, -1)

        # Project attention output
        attention_output = self.output_projection(attention_output)
        return attention_output
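For completeness, this is how I call it. The model construction and compile call are the ones shown in the traceback below; the input shape is just an arbitrary example (any (batch, seq_len, 32) tensor triggers the same error for me):

# Arbitrary example input: (batch, seq_len, input_dim) with input_dim=32
input = torch.randn(2, 16, 32, device="cuda")

model = MyModel(32, 1).cuda()
model = torch.compile(model)
output = model(input)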
The error I am getting is the following:
---------------------------------------------------------------------------
InternalTorchDynamoError Traceback (most recent call last)
<ipython-input-5-454293cadcc3> in <cell line: 4>()
2 model = MyModel(32, 1).cuda()
3 model = torch.compile(model)
----> 4 output = model(input)
5 output
45 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
1737
1738 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1748
1749 result = None
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py in _fn(*args, **kwargs)
463
464 try:
--> 465 return fn(*args, **kwargs)
466 finally:
467 # Restore the dynamic layer stack depth if necessary.
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
1737
1738 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1748
1749 result = None
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in __call__(self, frame, cache_entry, frame_state)
1267 with compile_lock, _disable_current_modes():
1268 # skip=1: skip this frame
-> 1269 return self._torchdynamo_orig_callable(
1270 frame, cache_entry, self.hooks, frame_state, skip=1
1271 )
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in __call__(self, frame, cache_entry, hooks, frame_state, skip)
1062 counters["frames"]["total"] += 1
1063 try:
-> 1064 result = self._inner_convert(
1065 frame, cache_entry, hooks, frame_state, skip=skip + 1
1066 )
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in __call__(self, frame, cache_entry, hooks, frame_state, skip)
524 )
525
--> 526 return _compile(
527 frame.f_code,
528 frame.f_globals,
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, export_constraints, hooks, cache_entry, cache_size, frame, frame_state, compile_id, skip)
950 else:
951 # Rewrap for clarity
--> 952 raise InternalTorchDynamoError(
953 f"{type(e).__qualname__}: {str(e)}"
954 ).with_traceback(e.__traceback__) from None
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, export_constraints, hooks, cache_entry, cache_size, frame, frame_state, compile_id, skip)
922 guarded_code = None
923 try:
--> 924 guarded_code = compile_inner(code, one_graph, hooks, transform)
925 return guarded_code
926 except Exception as e:
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in compile_inner(code, one_graph, hooks, transform)
664 with dynamo_timed("_compile.compile_inner", phase_name="entire_frame_compile"):
665 with CompileTimeInstructionCounter.record():
--> 666 return _compile_inner(code, one_graph, hooks, transform)
667
668 @compile_time_strobelight_meta(phase_name="compile_inner")
/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py in wrapper_function(*args, **kwargs)
85
86 if not StrobelightCompileTimeProfiler.enabled:
---> 87 return function(*args, **kwargs)
88
89 return StrobelightCompileTimeProfiler.profile_compile_time(
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in _compile_inner(code, one_graph, hooks, transform)
697 CompileContext.get().attempt = attempt
698 try:
--> 699 out_code = transform_code_object(code, transform)
700 break
701 except exc.RestartAnalysis as e:
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py in transform_code_object(code, transformations, safe)
1320 propagate_line_nums(instructions)
1321
-> 1322 transformations(instructions, code_options)
1323 return clean_and_assemble_instructions(instructions, keys, code_options)[1]
1324
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in _fn(*args, **kwargs)
217 )
218 try:
--> 219 return fn(*args, **kwargs)
220 finally:
221 cleanup.close()
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in transform(instructions, code_options)
632 try:
633 with tracing(tracer.output.tracing_context), tracer.set_current_tx():
--> 634 tracer.run()
635 except exc.UnspecializeRestartAnalysis:
636 speculation_log.clear()
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in run(self)
2794
2795 def run(self):
-> 2796 super().run()
2797
2798 def match_nested_cell(self, name, cell):
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in run(self)
981 try:
982 self.output.push_tx(self)
--> 983 while self.step():
984 pass
985 except BackendCompilerFailed:
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in step(self)
893
894 try:
--> 895 self.dispatch_table[inst.opcode](self, inst)
896 return not self.output.should_exit
897 except exc.ObservedException as e:
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in wrapper(self, inst)
580 return handle_graph_break(self, inst, speculation.reason)
581 try:
--> 582 return inner_fn(self, inst)
583 except Unsupported as excp:
584 if self.generic_context_manager_depth > 0:
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in CALL_FUNCTION_KW(self, inst)
1690 kwargs = dict(zip(argnames, kwargs_list))
1691 assert len(kwargs) == len(argnames)
-> 1692 self.call_function(fn, args, kwargs)
1693
1694 def LOAD_METHOD_SUPER(self, inst):
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in call_function(self, fn, args, kwargs)
828 if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
829 raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
--> 830 self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
831
832 def inline_user_function_return(self, fn, args, kwargs):
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py in call_function(self, tx, args, kwargs)
322 )
323
--> 324 return super().call_function(tx, args, kwargs)
325
326
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py in call_function(self, tx, args, kwargs)
109 kwargs: "Dict[str, VariableTracker]",
110 ) -> "VariableTracker":
--> 111 return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
112
113 def call_hasattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in inline_user_function_return(self, fn, args, kwargs)
834 A call to some user defined function by inlining it.
835 """
--> 836 return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
837
838 def get_line_of_code_header(self, lineno=None):
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in inline_call(cls, parent, func, args, kwargs)
3009 def inline_call(cls, parent, func, args, kwargs):
3010 with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
-> 3011 return cls.inline_call_(parent, func, args, kwargs)
3012
3013 @staticmethod
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in inline_call_(parent, func, args, kwargs)
3137 try:
3138 with strict_ctx:
-> 3139 tracer.run()
3140 except exc.ObservedException as e:
3141 msg = f"Observed exception DURING INLING {code} : {e}"
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in run(self)
981 try:
982 self.output.push_tx(self)
--> 983 while self.step():
984 pass
985 except BackendCompilerFailed:
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in step(self)
893
894 try:
--> 895 self.dispatch_table[inst.opcode](self, inst)
896 return not self.output.should_exit
897 except exc.ObservedException as e:
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in wrapper(self, inst)
580 return handle_graph_break(self, inst, speculation.reason)
581 try:
--> 582 return inner_fn(self, inst)
583 except Unsupported as excp:
584 if self.generic_context_manager_depth > 0:
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in CALL_FUNCTION(self, inst)
1600 args = self.popn(inst.argval)
1601 fn = self.pop()
-> 1602 self.call_function(fn, args, {})
1603
1604 @break_graph_if_unsupported(push=1)
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in call_function(self, fn, args, kwargs)
828 if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
829 raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
--> 830 self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
831
832 def inline_user_function_return(self, fn, args, kwargs):
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py in call_function(self, tx, args, kwargs)
322 )
323
--> 324 return super().call_function(tx, args, kwargs)
325
326
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py in call_function(self, tx, args, kwargs)
109 kwargs: "Dict[str, VariableTracker]",
110 ) -> "VariableTracker":
--> 111 return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
112
113 def call_hasattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in inline_user_function_return(self, fn, args, kwargs)
834 A call to some user defined function by inlining it.
835 """
--> 836 return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
837
838 def get_line_of_code_header(self, lineno=None):
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in inline_call(cls, parent, func, args, kwargs)
3009 def inline_call(cls, parent, func, args, kwargs):
3010 with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
-> 3011 return cls.inline_call_(parent, func, args, kwargs)
3012
3013 @staticmethod
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in inline_call_(parent, func, args, kwargs)
3137 try:
3138 with strict_ctx:
-> 3139 tracer.run()
3140 except exc.ObservedException as e:
3141 msg = f"Observed exception DURING INLING {code} : {e}"
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in run(self)
981 try:
982 self.output.push_tx(self)
--> 983 while self.step():
984 pass
985 except BackendCompilerFailed:
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in step(self)
893
894 try:
--> 895 self.dispatch_table[inst.opcode](self, inst)
896 return not self.output.should_exit
897 except exc.ObservedException as e:
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in wrapper(self, inst)
580 return handle_graph_break(self, inst, speculation.reason)
581 try:
--> 582 return inner_fn(self, inst)
583 except Unsupported as excp:
584 if self.generic_context_manager_depth > 0:
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in CALL_FUNCTION(self, inst)
1600 args = self.popn(inst.argval)
1601 fn = self.pop()
-> 1602 self.call_function(fn, args, {})
1603
1604 @break_graph_if_unsupported(push=1)
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in call_function(self, fn, args, kwargs)
828 if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
829 raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
--> 830 self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
831
832 def inline_user_function_return(self, fn, args, kwargs):
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/misc.py in call_function(self, tx, args, kwargs)
975 kwargs: "Dict[str, VariableTracker]",
976 ) -> "VariableTracker":
--> 977 return self.fn(*args, **kwargs)
978
979
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/misc.py in create(callable, **kwargs)
386 if kwargs:
387 unimplemented(f"inspect.signature with {kwargs}")
--> 388 return InspectSignatureVariable(
389 callable, mutable_local=variables.base.MutableLocal()
390 )
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/misc.py in __init__(self, inspected, **kwargs)
403 self.parameters = list(self.signature.parameters.items())
404 else:
--> 405 self.fn = self.inspected.as_python_constant()
406 self.signature = inspect.signature(self.fn)
407 self.parameters = list(self.signature.parameters.items())
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/base.py in as_python_constant(self)
215 def as_python_constant(self):
216 """For constants"""
--> 217 raise NotImplementedError(f"{self} is not a constant")
218
219 def guard_as_python_constant(self):
InternalTorchDynamoError: NotImplementedError: NestedUserFunctionVariable() is not a constant
from user code:
File "<ipython-input-4-42bb243baef1>", line 26, in forward
bidirectional_block = create_block_mask(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/attention/flex_attention.py", line 827, in create_block_mask
mod_type = _get_mod_type(mask_mod)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/attention/flex_attention.py", line 63, in _get_mod_type
for param in inspect.signature(fn).parameters.values()
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
Am I doing something wrong? Is using a block mask the recommended way to do attention masking with flex attention?
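In case it helps narrow things down: my guess is that the failure comes from Dynamo trying to trace create_block_mask (and the inspect.signature call inside it) on a nested closure. A workaround I would expect to work, though I have not verified it, is to keep forward compiled but exclude the mask construction from tracing with torch.compiler.disable, so create_block_mask runs in eager and only the finished BlockMask enters the compiled graph. A rough sketch of what I mean (the helper name is mine; it assumes self-attention with Q_LEN == KV_LEN, as in the repro above):

# Untested sketch: build the BlockMask in eager by excluding it from Dynamo tracing
@torch.compiler.disable
def build_bidirectional_block_mask(attention_mask, num_heads):
    batch_size, kv_len = attention_mask.shape
    return create_block_mask(
        _bidirectional_block_mask(attention_mask),
        B=batch_size,
        H=num_heads,
        Q_LEN=kv_len,
        KV_LEN=kv_len,
        device=attention_mask.device,
    )

# In forward(), the create_block_mask call would then become:
#     bidirectional_block = build_bidirectional_block_mask(attention_mask, self.num_heads)

But if there is an intended way to create block masks inside a compiled forward, I would prefer to do that instead.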