(I also opened a PyTorch GitHub issue for this here.)
Describe the bug
Compiling a module that wraps `torch.nn.attention.flex_attention` with a custom `score_mod` fails under `torch.compile` (Inductor).
- Eager execution works correctly (a minimal eager-only check is sketched after the repro below).
- With `torch.compile(dynamic=False)` I get:

  ```
  BackendCompilerFailed: backend='inductor' raised:
  LoweringException: AssertionError: convert FlexibleLayout to FixedLayout first
  ```

- With `torch.compile(dynamic=False, mode="max-autotune-no-cudagraphs")` I get:

  ```
  BackendCompilerFailed: backend='inductor' raised:
  LoweringException: NoValidChoicesError
  ```

- Tried with both `bfloat16` (autocast) and `float16`: same errors.
Minimal repro
```python
import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.attention import SDPBackend


class FlexAttentionCPB(nn.Module):
    def __init__(self, N: int, R: int, H: int = 6, hidden: int = 32):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(2, hidden), nn.GELU(), nn.Linear(hidden, H, bias=False))
        self.gamma = nn.Parameter(torch.zeros(H))
        self.H = H
        self.init_tables(N, R)
        self.register_buffer("r_cutoff", torch.tensor(R, dtype=torch.long), persistent=False)

    def init_tables(self, N: int, R: int):
        # continuous position bias – SwinV2
        P = N - R
        S = int(P**0.5)
        assert S * S == P
        rng = torch.arange(-(S - 1), S, dtype=torch.float32)
        dY, dX = torch.meshgrid(rng, rng, indexing="ij")
        rel = torch.stack([dY / max(S - 1, 1), dX / max(S - 1, 1)], dim=-1).reshape(-1, 2)
        rel_table = torch.sign(rel) * torch.log1p(rel.abs())
        self.register_buffer("rel_table", rel_table, persistent=False)
        yy, xx = torch.arange(S), torch.arange(S)
        Y, X = torch.meshgrid(yy, xx, indexing="ij")
        flat = torch.stack([Y, X], 0).flatten(1)
        d = flat[:, :, None] - flat[:, None, :]
        d = d.permute(1, 2, 0).contiguous()
        d[:, :, 0] += S - 1
        d[:, :, 1] += S - 1
        d[:, :, 0] *= 2 * S - 1
        l_idx = d.sum(-1).to(torch.long)
        idx = torch.full((N, N), 0, dtype=torch.long)
        idx[R:, R:] = l_idx
        self.register_buffer("idx_table", idx, persistent=False)

    def _score_mod(self, mu: torch.Tensor):
        bt = self.mlp(self.rel_table)
        idx = self.idx_table
        mu_q, mu_k = mu.unbind(2)
        gam_sig = torch.sigmoid(self.gamma)

        def score_mod(score, b, h, q, kv):
            has_bias = (q >= self.r_cutoff) & (kv >= self.r_cutoff)
            l2 = idx[q, kv]
            bias = bt[l2, h]
            w_gate = gam_sig[h] * (mu_q[b, h, q] + mu_k[b, h, kv])
            return score + has_bias.to(score.dtype) * w_gate * bias

        return score_mod

    def forward(self, q, k, v, mu):
        return flex_attention(q, k, v, score_mod=self._score_mod(mu))


def main():
    device = "cuda"
    B, N, R, d, H = 2, 18, 2, 32, 4
    mod = FlexAttentionCPB(N, R, H).to(device)
    mod = torch.compile(mod, dynamic=False)  # also tested mode="max-autotune-no-cudagraphs"
    q = torch.randn(B, H, N, d, device=device)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    mu = torch.randn(B, H, 2, N, device=device)
    with torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            out = mod(q, k, v, mu)
    out.norm().backward()
    print("done")


if __name__ == "__main__":
    main()
```
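For reference, the "eager works" entry in the ablation below is just the same module and inputs with no `torch.compile`. A minimal sketch, reusing the names from `main()` above:

```python
# Eager-only sanity check (sketch): identical module and inputs, no torch.compile.
# This path runs forward and backward without error.
mod_eager = FlexAttentionCPB(N, R, H).to(device)
out_eager = mod_eager(q, k, v, mu)
out_eager.norm().backward()
```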
Ablation
- Eager (no `torch.compile`) → works
- `torch.compile(dynamic=False)` → `FlexibleLayout` → `FixedLayout` assertion (a further compile variant is sketched after this list)
- `torch.compile(dynamic=False, mode="max-autotune-no-cudagraphs")` → `NoValidChoicesError`
- Backends: tested with `SDPBackend.FLASH_ATTENTION`.
- Dtypes: bf16 autocast and fp16 both fail under compile.
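For completeness, the variant below compiles `flex_attention` itself rather than the wrapping `nn.Module`, mirroring the pattern in the FlexAttention docs. I have not run it for this report, so treat it as a sketch of an additional ablation rather than a result.

```python
# Hypothetical extra ablation (not run for this report): compile flex_attention itself
# instead of the nn.Module that wraps it.
import torch
from torch.nn.attention.flex_attention import flex_attention

compiled_flex = torch.compile(flex_attention, dynamic=False)
# e.g., inside FlexAttentionCPB.forward:
#   return compiled_flex(q, k, v, score_mod=self._score_mod(mu))
```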
Environment
python: 3.10.12 (GCC 11.4.0)
cuda: 12.4
cudnn: 90100
torch: 2.6.0+cu124
torchvision: 0.21.0+cu124
torchaudio: 2.6.0+cu124
triton: 3.2.0
GPU: NVIDIA A100-SXM4-80GB (sm 8.0)
(Full `python -m torch.utils.collect_env` output attached.)
Expected behavior
FlexAttention with a custom score modifier should compile successfully, or at least provide a clear error if the pattern/layout is unsupported.
Actual behavior
Inductor compilation fails with LoweringException under multiple modes.
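If it helps triage, verbose Dynamo/Inductor logging can be enabled before calling `main()`; a sketch (separate from how the attached logs were collected):

```python
# Optional (sketch): surface the full lowering trace when re-running the repro.
import logging
import torch._logging

torch._logging.set_logs(dynamo=logging.DEBUG, inductor=logging.DEBUG, output_code=True)
```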
Additional context
- This looks specific to FlexAttention + a custom `score_mod` (a possible stopgap is sketched after this list).
- Happy to retest on nightly if maintainers think this may already be addressed.
- Attached logs: `compile_default_log.txt`, `compile_autotune_no_cudagraphs_log.txt`, `torch_collect_env.txt`, `nvidia_smi.xml`, `pip_freeze.txt` (full list on the GitHub issue).
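If a temporary workaround is useful while this is open (my assumption, not something the report depends on), one option is to exclude just this module from compilation so the surrounding model can still go through `torch.compile`. `FlexAttentionCPBEager` below is an illustrative name, not part of the repro:

```python
# Hypothetical stopgap: run this module eagerly while the rest of the model stays compiled.
import torch


class FlexAttentionCPBEager(FlexAttentionCPB):
    @torch.compiler.disable  # skip Dynamo/Inductor for this forward only
    def forward(self, q, k, v, mu):
        return super().forward(q, k, v, mu)
```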
Error logs
See attached (`compile_default_log.txt`, `compile_autotune_no_cudagraphs_log.txt`).
Versions
See attached `torch_collect_env.txt`; relevant pip packages below:
[pip3] mypy_extensions==1.1.0
[pip3] numpy==2.1.2
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] torch==2.6.0+cu124
[pip3] torchaudio==2.6.0+cu124
[pip3] torchvision==0.21.0+cu124
[pip3] triton==3.2.0