Flagging a function with torch._dynamo.disallow_in_graph still causes a graph break error

I have the following code:

import torch

# triton_flash_attn_fn is my Triton-based flash-attention function
# (implementation included at the end of this issue)

class TestModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = triton_flash_attn_fn

    def forward(self, x):
        return self.attn(x, x, x, n_heads=8)


model = TestModel()
# create a dummy input
x = torch.randn(2, 1024, 1024).cuda().bfloat16()
# exclude the Triton function from the compiled graph (or so I expected)
torch._dynamo.disallow_in_graph(triton_flash_attn_fn)
model = torch.compile(model)
with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type='cuda'):
    # run once to trigger compilation
    r = model(x)
    print(r)

It runs fine without torch.compile, but with torch.compile the model's forward pass fails with a graph break inside triton_flash_attn_fn. So I tried torch._dynamo.disallow_in_graph to exclude this function from the compiled graph, but the code still fails in the same way.

Q: Shouldn’t disallow_in_graph prevent TorchDynamo from trying to include the function during graph construction?

From the docs:

torch._dynamo.disallow_in_graph(fn) [[source]](https://pytorch.org/docs/stable/_modules/torch/_dynamo.html#disallow_in_graph)

Customize which functions TorchDynamo will exclude in the generated graph and force a graph break on.
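
My reading of that is that it applies to an op Dynamo already knows how to put in the graph as-is, something like the following (a minimal sketch based on my understanding of the docs; torch.sub is just a stand-in op):

import torch

# per the docs, disallow_in_graph forces a graph break around an
# otherwise-allowed op, so the op runs eagerly outside the compiled graph
torch._dynamo.disallow_in_graph(torch.sub)

@torch.compile
def fn(x):
    x = torch.add(x, 1)
    x = torch.sub(x, 1)  # expected: graph break here, torch.sub runs eagerly
    return torch.add(x, 1)

fn(torch.randn(4))

I expected the same mechanism to apply to my own function.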

Environment

blis==0.7.9
catalogue==2.0.8
cmake==3.27.6
confection==0.0.4
cymem==2.0.7
einops==0.7.0
filelock==3.12.4
Jinja2==3.1.2
langcodes==3.3.0
lit==17.0.2
MarkupSafe==2.1.2
mpmath==1.3.0
murmurhash==1.0.9
networkx==3.1
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
pathy==0.10.1
preshed==3.0.8
smart-open==6.3.0
srsly==2.4.5
sympy==1.12
torch==2.0.1
triton==2.0.0
triton-pre-mlir @ git+https://github.com/vchiley/triton.git@86c7fe23397467ade531513291f729c12dd8d15e#subdirectory=python
typer==0.7.0
typing_extensions==4.8.0
wasabi==0.10.1

Here is the error:

myenvpath/bin/python myscriptpath/t_triton_compile.py 
Traceback (most recent call last):
  File "myenvpath/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
    out_code = transform_code_object(code, transform)
  File "myenvpath/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
    transformations(instructions, code_options)
  File "myenvpath/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
    tracer.run()
  File "myenvpath/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
    super().run()
  File "myenvpath/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "myenvpath/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "myenvpath/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 342, in wrapper
    return inner_fn(self, inst)
  File "myenvpath/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 148, in impl
    self.push(fn_var.call_function(self, self.popn(nargs), {}))
  File "myenvpath/lib/python3.9/site-packages/torch/_dynamo/variables/builtin.py", line 566, in call_function
    result = handler(tx, *args, **kwargs)
  File "myenvpath/lib/python3.9/site-packages/torch/_dynamo/variables/builtin.py", line 791, in call_getitem
    return args[0].call_method(tx, "__getitem__", args[1:], kwargs)
  File "myenvpath/lib/python3.9/site-packages/torch/_dynamo/variables/dicts.py", line 66, in call_method
    return self.getitem_const(args[0])
  File "myenvpath/lib/python3.9/site-packages/torch/_dynamo/variables/dicts.py", line 51, in getitem_const
    return self.items[ConstDictVariable.get_key(arg)].add_options(self, arg)
KeyError: ('2-.-0-.-0--394352f6a8351feaac334fbb8cc63fa4-46c7c5d46afed8316facd72e7e581bec-eabceb538e9d90c485174c2451fb9ee3-39e3c68a052760cc345a9147b0d68f7d-5c5e32ff210f3b7f56c98ca29917c25e-06f0df2d61979d629033f4a22eff5198-4ac47e74762ba6a774cceea0e1e75ae6-13b7ffc189bd9fba7696034bbcfee151', (torch.bfloat16, torch.bfloat16, torch.bfloat16, None, torch.bfloat16, torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), ('none', False, 128, True, True, True, 128, 128), (True, True, True, (False,), True, True, True, (False,), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "myscriptpath/t_triton_compile.py", line 855, in <module>
    r = model(x)
  File "myenvpath/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "myenvpath/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "myenvpath/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "myscriptpath/t_triton_compile.py", line 845, in forward
    return self.attn(x, x, x, n_heads=8)
  File "myscriptpath/t_triton_compile.py", line 825, in triton_flash_attn_fn
    query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
  File "myscriptpath/t_triton_compile.py", line 826, in <graph break in triton_flash_attn_fn>
    key = rearrange(key, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads)
  File "myscriptpath/t_triton_compile.py", line 827, in <graph break in triton_flash_attn_fn>
    value = rearrange(value,
  File "myscriptpath/t_triton_compile.py", line 831, in <graph break in triton_flash_attn_fn>
    attn_output = flash_attn_func(  # type: ignore
  File "myenvpath/lib/python3.9/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "myscriptpath/t_triton_compile.py", line 788, in forward
    o, lse, ctx.softmax_scale = _flash_attn_forward(
  File "myscriptpath/t_triton_compile.py", line 601, in _flash_attn_forward
    _fwd_kernel[grid](
  File "myenvpath/lib/python3.9/site-packages/triton_pre_mlir/runtime/jit.py", line 106, in launcher
    return self.run(*args, grid=grid, **kwargs)
  File "myenvpath/lib/python3.9/site-packages/triton_pre_mlir/runtime/autotuner.py", line 200, in run
    return self.fn.run(*args, **kwargs)
  File "<string>", line 3, in _fwd_kernel
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 3, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 5, in <graph break in _fwd_kernel>
  File "<string>", line 10, in <graph break in _fwd_kernel>
  File "<string>", line 16, in <graph break in _fwd_kernel>
  File "<string>", line 16, in <graph break in _fwd_kernel>
  File "myenvpath/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 337, in catch_errors
    return callback(frame, cache_size, hooks)
  File "myenvpath/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 404, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "myenvpath/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
    return fn(*args, **kwargs)
  File "myenvpath/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
    return _compile(
  File "myenvpath/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "myenvpath/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 393, in _compile
    exception_handler(e, code, frame)
  File "myenvpath/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 184, in exception_handler
    augment_exc_message(e)
  File "myenvpath/lib/python3.9/site-packages/torch/_dynamo/exc.py", line 127, in augment_exc_message
    exc.args = (KeyErrorMsg(old_msg + msg),) + exc.args[1:]
TypeError: can only concatenate tuple (not "str") to tuple

Finally, here is the implementation of the triton_flash_attn_fn function:

I also tried torch 2.1, and this is the error:

Traceback (most recent call last):
  File "myscriptpath/t_triton_compile.py", line 853, in <module>
    torch._dynamo.disallow_in_graph(model.attn)
  File "myenvpath/lib/python3.9/site-packages/torch/_dynamo/decorators.py", line 138, in disallow_in_graph
    return _disallow_in_graph_helper(throw_if_not_allowed=True)(fn)
  File "myenvpath/lib/python3.9/site-packages/torch/_dynamo/decorators.py", line 106, in inner
    raise IncorrectUsage(
torch._dynamo.exc.IncorrectUsage: disallow_in_graph is expected to be used on an already allowed callable (like torch.* ops). Allowed callables means callables that TorchDynamo puts as-is in the extracted graph.

So if I understand correctly, I can only use disallow_in_graph on torch.* ops? (Or, more precisely, on callables that TorchDynamo does not need to trace into the graph, because it already puts them in as-is?)
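
As a side note: if the goal is simply to keep Dynamo away from the function entirely, is torch._dynamo.disable the intended tool instead? A minimal sketch of what I mean (untested with the Triton kernel; my understanding is that disable makes Dynamo skip the wrapped frame and always run it eagerly):

import torch

# hypothetical: wrap the Triton attention function so Dynamo never traces it;
# calls should then run eagerly, with a graph break around them
attn_eager = torch._dynamo.disable(triton_flash_attn_fn)

class TestModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = attn_eager  # Dynamo skips this callable

    def forward(self, x):
        return self.attn(x, x, x, n_heads=8)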