Apologies for the late response — details below:
SDPA produces NaN with torch.compile mode="max-autotune-no-cudagraphs" when batch_size × seq_len == 65536 (64k)
Summary
F.scaled_dot_product_attention produces NaN values when compiled with torch.compile(model, mode="max-autotune-no-cudagraphs", fullgraph=True) and the total number of tokens per batch (batch_size × seq_len) equals 65536 (64k).
This is a regression between NGC PyTorch containers 25.09-py3 (working) and 25.10-py3 / 25.11-py3 / 25.12-py3, which are all broken.
Environment
- Broken:
nvcr.io/nvidia/pytorch:25.12-py3, 25.11-py3, 25.10-py3
- Working:
nvcr.io/nvidia/pytorch:25.09-py3
- GPU: NVIDIA RTX PRO 6000 MaxQ 96GB
Threshold Testing
| batch_size | seq_len | tokens | result |
|-----------:|--------:|-------:|:-------|
| 31 | 2048 | 63488 | OK |
| 32 | 2048 | 65536 | NaN |
| 16 | 2048 | 32768 | OK |
| 32 | 1024 | 32768 | OK |
| 16 | 4096 | 65536 | NaN |
| 8 | 8192 | 65536 | NaN |
| 4 | 16384 | 65536 | NaN |
| 2 | 32768 | 65536 | NaN |
| 1 | 65536 | 65536 | NaN |
We tested other combinations too, but the theme was always the same - a 64k token size was consistently producing NaNs.
Minimal Reproduction
This code is derived from Karpathy’s nanochat repository, reduced to isolate the NaN-producing behavior.
import os
# NOTE: must be set before `import torch` for the allocator to pick it up.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hyperparameters (reduced from nanochat to isolate the SDPA NaN).
VOCAB = 65536       # vocabulary size
SEQ = 2048          # sequence length
BATCH = 32          # fails with 32 (65536 tokens), works for smaller sizes
N_LAYER = 20        # number of attention blocks
N_EMBD = 1280       # embedding width
N_HEAD = 10         # attention heads (head_dim = N_EMBD // N_HEAD = 128)
GRAD_ACCUM = 8      # gradient-accumulation micro-steps per optimizer step
def norm(x):
return F.rms_norm(x, (x.size(-1),))
class Attn(nn.Module):
    """Single causal self-attention block (attention only, no MLP).

    Kept deliberately minimal: this is the smallest unit that still
    reproduces the NaN under max-autotune-no-cudagraphs.
    """
    def __init__(self):
        super().__init__()
        # Fused QKV projection and output projection, both bias-free.
        self.qkv = nn.Linear(N_EMBD, 3 * N_EMBD, bias=False)
        self.proj = nn.Linear(N_EMBD, N_EMBD, bias=False)

    def forward(self, x):
        B, T, _ = x.shape
        # (B, T, 3, N_HEAD, head_dim): split the fused projection into q/k/v.
        qkv = self.qkv(x).view(B, T, 3, N_HEAD, N_EMBD // N_HEAD)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        # QK-norm: RMS-normalize queries and keys before attention.
        q, k = norm(q), norm(k)
        # Move heads to dim 1 -> (B, N_HEAD, T, head_dim), the layout SDPA expects.
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        # The call that yields NaNs when compiled with
        # mode="max-autotune-no-cudagraphs" and B * T == 65536.
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # Merge heads back and project to the residual stream.
        return self.proj(y.transpose(1, 2).contiguous().view(B, T, -1))
class Model(nn.Module):
    """Attention-only transformer LM: embedding -> N_LAYER pre-norm residual
    attention blocks -> RMS-norm -> linear head -> cross-entropy loss."""
    def __init__(self):
        super().__init__()
        # Embedding table kept in bf16 (as in nanochat).
        self.wte = nn.Embedding(VOCAB, N_EMBD, dtype=torch.bfloat16)
        self.layers = nn.ModuleList([Attn() for _ in range(N_LAYER)])
        self.head = nn.Linear(N_EMBD, VOCAB, bias=False)

    def forward(self, idx, tgt):
        # idx, tgt: (BATCH, SEQ) integer token ids; returns scalar CE loss.
        x = norm(self.wte(idx))
        for layer in self.layers:
            # Pre-norm residual attention (no MLP sub-block).
            x = x + layer(norm(x))
        # Cast logits to fp32 before the loss for numerical stability.
        logits = self.head(norm(x)).float()
        # Soft-cap logits to (-15, 15) with tanh, as in nanochat.
        logits = 15 * torch.tanh(logits / 15)
        return F.cross_entropy(logits.view(-1, VOCAB), tgt.view(-1))
torch.manual_seed(0)
model = Model().cuda()
# Zero-init the LM head and every attention output projection so the model
# starts as the identity in the residual stream and the initial loss is
# exactly ln(VOCAB) — makes divergence/NaN unambiguous.
nn.init.zeros_(model.head.weight)
for layer in model.layers:
    nn.init.zeros_(layer.proj.weight)
# leads to NaNs when batch * seq == 65536
model = torch.compile(model, dynamic=False, mode="max-autotune-no-cudagraphs", fullgraph=True)
# works correctly
# model = torch.compile(model, dynamic=False, fullgraph=True)
# Note: @karpathy uses bf16 for AdamW, here we use fp32
opt = torch.optim.AdamW(model.parameters(), lr=0.1, fused=True)
# Dedicated generator so the random token stream is reproducible across runs.
g = torch.Generator(device="cuda").manual_seed(1234)
for step in range(8):
    for _ in range(GRAD_ACCUM):
        # Random tokens are enough to trigger the issue; no real data needed.
        x = torch.randint(0, VOCAB, (BATCH, SEQ), device="cuda", generator=g)
        y = torch.randint(0, VOCAB, (BATCH, SEQ), device="cuda", generator=g)
        with torch.amp.autocast("cuda", torch.bfloat16):  # unnecessary
            loss = model(x, y)
        # Bail out as soon as any micro-step loss becomes non-finite.
        if not math.isfinite(loss.item()):
            print(f"FAIL: NaN at step {step}")
            exit(1)
        # Scale by GRAD_ACCUM so accumulated gradients average the micro-steps.
        (loss / GRAD_ACCUM).backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()
    model.zero_grad(set_to_none=True)
    print(f"step {step}: {loss.item():.4f}")
print("OK")
Output
step 0: 11.0904
step 1: 11.0904
step 2: 11.0904
FAIL: NaN at step 3
Workarounds
Any of these prevent the issue:
- Use `torch.compile(model, dynamic=False, fullgraph=True)` without `mode="max-autotune-no-cudagraphs"`
- Keep `batch_size × seq_len < 65536`
- Use the older 25.09 container
We also tried the latest torch nightly — same result.