torch.compile + flex_attention with custom score_mod fails (FlexibleLayout assertion / NoValidChoicesError)

(I also filed this as an issue on the PyTorch GitHub tracker.)

:bug: Describe the bug

Compiling a module that wraps torch.nn.attention.flex_attention with a custom score_mod fails under torch.compile (Inductor).

  • Eager execution works correctly.
  • With torch.compile(dynamic=False) I get:

BackendCompilerFailed: backend='inductor' raised:
LoweringException: AssertionError: convert FlexibleLayout to FixedLayout first

  • With torch.compile(dynamic=False, mode="max-autotune-no-cudagraphs") I get:

BackendCompilerFailed: backend='inductor' raised:
LoweringException: NoValidChoicesError

  • Tried with both bfloat16 (autocast) and float16 — same errors.

Minimal repro

import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.attention import SDPBackend, sdpa_kernel

class FlexAttentionCPB(nn.Module):
    def __init__(self, N: int, R: int, H: int = 6, hidden: int = 32):
        super().__init__()
        # SwinV2-style CPB: a small MLP maps 2D log-scaled relative coordinates to a per-head bias.
        self.mlp = nn.Sequential(nn.Linear(2, hidden), nn.GELU(), nn.Linear(hidden, H, bias=False))
        self.gamma = nn.Parameter(torch.zeros(H))  # per-head gate logits
        self.H = H
        self.init_tables(N, R)
        self.register_buffer("r_cutoff", torch.tensor(R, dtype=torch.long), persistent=False)

    def init_tables(self, N: int, R: int):
        # Continuous position bias (CPB) tables, as in SwinV2: the first R tokens get no bias;
        # the remaining P = N - R tokens form an S x S grid.
        P = N - R
        S = int(P**0.5)
        assert S * S == P
        # Log-scaled relative coordinates fed to the CPB MLP: shape ((2S-1)^2, 2).
        rng = torch.arange(-(S - 1), S, dtype=torch.float32)
        dY, dX = torch.meshgrid(rng, rng, indexing="ij")
        rel = torch.stack([dY / max(S - 1, 1), dX / max(S - 1, 1)], dim=-1).reshape(-1, 2)
        rel_table = torch.sign(rel) * torch.log1p(rel.abs())
        self.register_buffer("rel_table", rel_table, persistent=False)

        # Relative-position index for every pair of grid positions: (P, P) entries in [0, (2S-1)^2).
        yy, xx = torch.arange(S), torch.arange(S)
        Y, X = torch.meshgrid(yy, xx, indexing="ij")
        flat = torch.stack([Y, X], 0).flatten(1)
        d = flat[:, :, None] - flat[:, None, :]
        d = d.permute(1, 2, 0).contiguous()
        d[:, :, 0] += S - 1
        d[:, :, 1] += S - 1
        d[:, :, 0] *= 2 * S - 1
        l_idx = d.sum(-1).to(torch.long)

        # (N, N) lookup from (q, kv) into the bias table; the first R rows/cols stay at 0
        # and are masked out in score_mod via r_cutoff.
        idx = torch.zeros((N, N), dtype=torch.long)
        idx[R:, R:] = l_idx
        self.register_buffer("idx_table", idx, persistent=False)

    def _score_mod(self, mu: torch.Tensor):
        bt = self.mlp(self.rel_table)        # ((2S-1)^2, H) per-head bias table
        idx = self.idx_table                 # (N, N) lookup into the bias table
        mu_q, mu_k = mu.unbind(2)            # each (B, H, N)
        gam_sig = torch.sigmoid(self.gamma)  # per-head gate in (0, 1)

        def score_mod(score, b, h, q, kv):
            # Tokens before r_cutoff receive no positional bias.
            has_bias = (q >= self.r_cutoff) & (kv >= self.r_cutoff)
            l2 = idx[q, kv]    # gather: relative-position bucket for this (q, kv) pair
            bias = bt[l2, h]   # gather: per-head CPB bias for that bucket
            w_gate = gam_sig[h] * (mu_q[b, h, q] + mu_k[b, h, kv])  # data-dependent gate
            return score + has_bias.to(score.dtype) * w_gate * bias

        return score_mod

    def forward(self, q, k, v, mu):
        return flex_attention(q, k, v, score_mod=self._score_mod(mu))

def main():
    device = "cuda"
    B, N, R, d, H = 2, 18, 2, 32, 4
    mod = FlexAttentionCPB(N, R, H).to(device)
    mod = torch.compile(mod, dynamic=False)  # also tested "max-autotune-no-cudagraphs"

    q = torch.randn(B, H, N, d, device=device)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    mu = torch.randn(B, H, 2, N, device=device)

    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            out = mod(q, k, v, mu)
            out.norm().backward()
    print("done")

if __name__ == "__main__":
    main()

Ablation

  • Eager (no torch.compile) → :white_check_mark: works
  • torch.compile(dynamic=False) → :cross_mark: FlexibleLayout → FixedLayout assertion
  • torch.compile(dynamic=False, mode="max-autotune-no-cudagraphs") → :cross_mark: NoValidChoicesError
  • Backend context: sdpa_kernel(SDPBackend.FLASH_ATTENTION), as in the repro above.
  • Dtypes: bf16 (autocast) and fp16 both fail under compile (a small sketch for switching between these configurations follows below).
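
For reference, a minimal driver sketch for toggling these configurations. This is illustrative only, not the exact script I ran: it assumes the FlexAttentionCPB class from the repro above is in scope, and the run_config helper name is mine.

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

def run_config(compile_kwargs=None, dtype=torch.bfloat16):
    # Build a fresh module per configuration so compile artifacts don't carry over.
    device = "cuda"
    B, N, R, d, H = 2, 18, 2, 32, 4
    mod = FlexAttentionCPB(N, R, H).to(device)
    if compile_kwargs is not None:
        mod = torch.compile(mod, **compile_kwargs)

    q = torch.randn(B, H, N, d, device=device)
    k, v = torch.randn_like(q), torch.randn_like(q)
    mu = torch.randn(B, H, 2, N, device=device)

    with sdpa_kernel(SDPBackend.FLASH_ATTENTION), torch.amp.autocast("cuda", dtype=dtype):
        out = mod(q, k, v, mu)
        out.norm().backward()

for kwargs in (None, {"dynamic": False}, {"dynamic": False, "mode": "max-autotune-no-cudagraphs"}):
    try:
        run_config(kwargs)
        print(kwargs, "-> ok")
    except Exception as e:
        print(kwargs, "->", type(e).__name__)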

Environment

python: 3.10.12 (GCC 11.4.0)
cuda: 12.4
cudnn: 90100
torch: 2.6.0+cu124
torchvision: 0.21.0+cu124
torchaudio: 2.6.0+cu124
triton: 3.2.0
GPU: NVIDIA A100-SXM4-80GB (sm 8.0)

(Full python -m torch.utils.collect_env output attached.)

Expected behavior

FlexAttention with a custom score modifier should compile successfully, or at least provide a clear error if the pattern/layout is unsupported.

Actual behavior

Inductor compilation fails with LoweringException under multiple modes.

Additional context

  • This looks specific to the FlexAttention + custom score_mod path; a stripped-down sketch of the pattern is included after this list.

  • Happy to retest on nightly if maintainers think this may already be addressed.

  • Attached logs:

    • compile_default_log.txt
    • compile_autotune_no_cudagraphs_log.txt
    • torch_collect_env.txt
    • nvidia_smi.xml
    • pip_freeze.txt

(The complete attachment list is on the GitHub issue.)
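
For orientation, the pattern in question, stripped of the module wrapper, is a score_mod that gathers a bias from closed-over index/bias tensors. The sketch below only illustrates that pattern with shapes matching the repro; I have not separately verified that this exact reduction fails, the attached repro is the verified case.

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, N, d = 2, 4, 18, 32
q = torch.randn(B, H, N, d, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

bias_table = torch.randn(49, H, device="cuda")           # per-head bias rows, (2S-1)^2 = 49 for S = 4
idx_table = torch.randint(0, 49, (N, N), device="cuda")  # maps (q_idx, kv_idx) -> bias row

def score_mod(score, b, h, q_idx, kv_idx):
    # Data-dependent gathers from captured tensors, as in the repro's CPB bias.
    return score + bias_table[idx_table[q_idx, kv_idx], h]

compiled_flex = torch.compile(flex_attention, dynamic=False)
out = compiled_flex(q, k, v, score_mod=score_mod)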

Error logs

See the attached compile_default_log.txt and compile_autotune_no_cudagraphs_log.txt for the full tracebacks.

Versions

See the attached torch_collect_env.txt for the full output; the relevant pip packages:

[pip3] mypy_extensions==1.1.0
[pip3] numpy==2.1.2
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] torch==2.6.0+cu124
[pip3] torchaudio==2.6.0+cu124
[pip3] torchvision==0.21.0+cu124
[pip3] triton==3.2.0