Torch.compile - getting `nans` with the latest version of torch

When running our training code on 4x RTX Blackwell 6000 cards with nvcr.io/nvidia/pytorch:25.09-py3, training is stable, but after upgrading to nvcr.io/nvidia/pytorch:25.10-py3 the loss quickly goes to NaN with the exact same training configuration. We also upgraded torch/triton to the nightlies and still get NaNs.

Training is, however, stable if we halve the batch size.

Presumably the later versions of torch/triton are choosing buggy kernels, but how do we open a ticket so someone can replicate this? It seems hardware-dependent. Is there a way for us to upload the inductor cache files?

Thanks!

Could you post a minimal and executable code snippet reproducing this issue?

Apologies for the late response; details below:

SDPA produces NaN with torch.compile mode="max-autotune-no-cudagraphs" when batch_size × seq_len == 65536 (64k)

Summary

F.scaled_dot_product_attention produces NaN values when compiled with torch.compile(model, mode="max-autotune-no-cudagraphs", fullgraph=True) and the total number of tokens per batch (batch_size × seq_len) equals 65536 (64k).

This is a regression between NGC PyTorch containers 25.09-py3 (working) and 25.10-py3 / 25.11-py3 / 25.12-py3, which are all broken.

Environment

  • Broken: nvcr.io/nvidia/pytorch:25.12-py3, 25.11-py3, 25.10-py3
  • Working: nvcr.io/nvidia/pytorch:25.09-py3
  • GPU: NVIDIA RTX PRO 6000 MaxQ 96GB

Threshold Testing

batch_size  seq_len  tokens  result
31          2048     63488   :white_check_mark: OK
32          2048     65536   :cross_mark: NaN
16          2048     32768   :white_check_mark: OK
32          1024     32768   :white_check_mark: OK
16          4096     65536   :cross_mark: NaN
8           8192     65536   :cross_mark: NaN
4           16384    65536   :cross_mark: NaN
2           32768    65536   :cross_mark: NaN
1           65536    65536   :cross_mark: NaN

We tested other combinations too, but the pattern held: 65536 (64k) tokens per batch consistently produced NaNs.
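The sweep logic above can be summarized as a simple predicate. This is a sketch with our own (hypothetical) naming, not part of the repro; it just encodes what the table shows, namely that in our testing the failures occurred exactly at 65536 tokens per batch:

```python
# Hypothetical helper (our naming): classify a (batch_size, seq_len) pair
# the way our sweep did. Per the table, runs NaN out exactly when the
# number of tokens per batch equals 65536.
def expect_nan(batch_size: int, seq_len: int, bad_tokens: int = 65536) -> bool:
    return batch_size * seq_len == bad_tokens

# The combinations from the table above:
cases = [(31, 2048), (32, 2048), (16, 2048), (32, 1024),
         (16, 4096), (8, 8192), (4, 16384), (2, 32768), (1, 65536)]
for b, s in cases:
    verdict = "NaN" if expect_nan(b, s) else "OK"
    print(f"{b:>2} x {s:>5} = {b * s:>5} tokens -> {verdict}")
```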

Minimal Reproduction

This code is derived from Karpathy’s nanochat repository, reduced to isolate the NaN-producing behavior.

import os

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

VOCAB = 65536
SEQ = 2048
BATCH = 32  # fails with 32 (65536 tokens), works for smaller sizes
N_LAYER = 20
N_EMBD = 1280
N_HEAD = 10
GRAD_ACCUM = 8


def norm(x):
    return F.rms_norm(x, (x.size(-1),))


class Attn(nn.Module):
    def __init__(self):
        super().__init__()
        self.qkv = nn.Linear(N_EMBD, 3 * N_EMBD, bias=False)
        self.proj = nn.Linear(N_EMBD, N_EMBD, bias=False)

    def forward(self, x):
        B, T, _ = x.shape
        qkv = self.qkv(x).view(B, T, 3, N_HEAD, N_EMBD // N_HEAD)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        q, k = norm(q), norm(k)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.proj(y.transpose(1, 2).contiguous().view(B, T, -1))


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(VOCAB, N_EMBD, dtype=torch.bfloat16)
        self.layers = nn.ModuleList([Attn() for _ in range(N_LAYER)])
        self.head = nn.Linear(N_EMBD, VOCAB, bias=False)

    def forward(self, idx, tgt):
        x = norm(self.wte(idx))
        for layer in self.layers:
            x = x + layer(norm(x))
        logits = self.head(norm(x)).float()
        logits = 15 * torch.tanh(logits / 15)
        return F.cross_entropy(logits.view(-1, VOCAB), tgt.view(-1))


torch.manual_seed(0)
model = Model().cuda()
nn.init.zeros_(model.head.weight)
for layer in model.layers:
    nn.init.zeros_(layer.proj.weight)

# leads to NaNs when batch * seq == 65536
model = torch.compile(model, dynamic=False, mode="max-autotune-no-cudagraphs", fullgraph=True)

# works correctly
# model = torch.compile(model, dynamic=False, fullgraph=True)

# Note: @karpathy uses bf16 for AdamW, here we use fp32
opt = torch.optim.AdamW(model.parameters(), lr=0.1, fused=True)
g = torch.Generator(device="cuda").manual_seed(1234)

for step in range(8):
    for _ in range(GRAD_ACCUM):
        x = torch.randint(0, VOCAB, (BATCH, SEQ), device="cuda", generator=g)
        y = torch.randint(0, VOCAB, (BATCH, SEQ), device="cuda", generator=g)
        with torch.amp.autocast("cuda", torch.bfloat16):  # unnecessary
            loss = model(x, y)
        if not math.isfinite(loss.item()):
            print(f"FAIL: NaN at step {step}")
            exit(1)
        (loss / GRAD_ACCUM).backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()
    model.zero_grad(set_to_none=True)
    print(f"step {step}: {loss.item():.4f}")

print("OK")

Output

step 0: 11.0904
step 1: 11.0904
step 2: 11.0904
FAIL: NaN at step 3

Workarounds

Any one of these prevents the issue:

  1. Use torch.compile(model, dynamic=False, fullgraph=True) without mode="max-autotune-no-cudagraphs"
  2. Keep batch_size × seq_len < 65536
  3. Use the older 25.09 container

And yes, we tried the latest torch nightly; same result.
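For anyone who wants to inspect the kernels inductor actually picked (and to our earlier question about sharing the cache), here is a sketch using the debug switches we understand PyTorch exposes; treat the exact set and output layout as assumptions and check the compile troubleshooting docs for your build:

```shell
# Log the Triton/C++ code inductor generates, plus a full debug trace
export TORCH_LOGS="output_code"
export TORCH_COMPILE_DEBUG=1
# Pin the inductor cache to a known directory so it can be archived and shared
export TORCHINDUCTOR_CACHE_DIR=/tmp/inductor_cache

python repro.py
tar czf inductor_cache.tgz /tmp/inductor_cache
```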