[SDPA] RTX 5080 backward gradients differ from CPU results with long sequences

When the std of the key is large and the sequence length is slightly increased, abnormal gradients occur during backward on the RTX 5080, while the gradient scale remains normal on the CPU.

Code:

The script below registers hooks on the inputs and output of SDPA and prints the maximum absolute value of each gradient.

import torch
from torch.nn.functional import scaled_dot_product_attention as sdpa
from torch.nn.functional import mse_loss

print("torch:", torch.__version__)
print("cudnn:", torch.backends.cudnn.version())


def verify_grad(device, seq_len, k_scale):
    print('')
    print(f'device={device}', f'seq_len={seq_len}', f'k_scale={k_scale}')
    q = torch.randn((1, 16, seq_len, 128), dtype=torch.float32, device=device, requires_grad=True)
    # Scaling k by k_scale makes the attention logits large, which is what triggers the issue.
    k = torch.randn((1, 16, seq_len, 128), dtype=torch.float32, device=device, requires_grad=True) * k_scale
    v = torch.randn((1, 16, seq_len, 128), dtype=torch.float32, device=device, requires_grad=True)
    target = torch.randn(1, 16, seq_len, 128, requires_grad=False).to(device)
    dropout = 0
    scaling = 0.08838834764831845  # 1 / sqrt(128), the default head-dim scaling
    is_causal = True
    # Print the maximum absolute gradient of each tensor during backward.
    q.register_hook(lambda grad: print(f"{'q':<10}", torch.max(torch.abs(grad)).item()))
    k.register_hook(lambda grad: print(f"{'k':<10}", torch.max(torch.abs(grad)).item()))
    v.register_hook(lambda grad: print(f"{'v':<10}", torch.max(torch.abs(grad)).item()))
    attn_output = sdpa(
        query=q,
        key=k,
        value=v,
        attn_mask=None,
        dropout_p=dropout,
        scale=scaling,
        is_causal=is_causal,
    )
    attn_output.register_hook(lambda grad: print(f"{'out':<10}", torch.max(torch.abs(grad)).item()))
    loss = mse_loss(attn_output, target)
    loss.backward()


# Baseline: unit-scale k, both devices agree.
verify_grad('cpu', 151, 1.0)
verify_grad('cuda', 151, 1.0)

# Large k scale with seq_len=151: CUDA gradients explode, CPU stays normal.
verify_grad('cpu', 151, 20.0)
verify_grad('cuda', 151, 20.0)

# Large k scale with a short sequence: both devices stay normal.
verify_grad('cpu', 20, 20.0)
verify_grad('cuda', 20, 20.0)
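
For context, with k_scale=20 the pre-softmax logits end up with a standard deviation of roughly k_scale (the scale of 1/sqrt(128) cancels the sqrt(d) growth of the q·k dot product), so the softmax rows are nearly one-hot. A minimal sketch, separate from the repro above, to confirm the logit scale:

import torch

torch.manual_seed(0)
d, seq_len, k_scale = 128, 151, 20.0
scaling = d ** -0.5  # same 1 / sqrt(128) as in the repro
q = torch.randn(1, 16, seq_len, d)
k = torch.randn(1, 16, seq_len, d) * k_scale
logits = (q @ k.transpose(-2, -1)) * scaling
print("logit std:", logits.std().item())        # roughly k_scale (~20)
print("logit max:", logits.abs().max().item())  # far into the saturated region of softmax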

Result:

torch: 2.8.0.dev20250615+cu128
cudnn: 90701

device=cpu seq_len=151 k_scale=1.0
out        3.560339973773807e-05
q          5.226455323281698e-05
v          6.997660238994285e-05
k          6.137766467873007e-05

device=cuda seq_len=151 k_scale=1.0
out        3.566638406482525e-05
q          5.120061177876778e-05
v          6.786640733480453e-05
k          6.284983101068065e-05

device=cpu seq_len=151 k_scale=20.0
out        4.353544136392884e-05
q          0.001123013673350215
v          0.0001919580390676856
k          5.425199560704641e-05

device=cuda seq_len=151 k_scale=20.0
out        4.079606878804043e-05
q          2272.349365234375
v          60.08359146118164
k          121.35572052001953   <- abnormal

device=cpu seq_len=20 k_scale=20.0
out        0.0003049526712857187
q          0.006180753465741873
v          0.0009748771553859115
k          0.000328497146256268

device=cuda seq_len=20 k_scale=20.0
out        0.00029924517730250955
q          0.006182350218296051
v          0.0008953474462032318
k          0.0003514259879011661
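
If it helps with triage, below is a hedged sketch that re-runs the failing configuration with SDPA restricted to one backend at a time (via torch.nn.attention.sdpa_kernel; verify_grad is the function from the repro above), to narrow down which fused kernel produces the abnormal gradients:

from torch.nn.attention import SDPBackend, sdpa_kernel

for backend in (SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION,
                SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION):
    print(f"\n--- backend: {backend.name} ---")
    try:
        with sdpa_kernel(backend):
            verify_grad('cuda', 151, 20.0)  # the failing case from above
    except RuntimeError as e:
        # Some backends may be unavailable for fp32 on this GPU / platform.
        print("backend not usable:", e)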

Versions

PyTorch version: 2.8.0.dev20250615+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 11 Pro (10.0.26100 64-bit)
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.12.7 | packaged by Anaconda, Inc. | (main, Oct  4 2024, 13:17:27) [MSC v.1929 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-11-10.0.26100-SP0
Is CUDA available: True
CUDA runtime version: 12.9.86
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 5080
Nvidia driver version: 576.52
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Name: AMD Ryzen 7 9700X 8-Core Processor             
Manufacturer: AuthenticAMD
Family: 107
Architecture: 9
ProcessorType: 3
DeviceID: CPU0
CurrentClockSpeed: 3800
MaxClockSpeed: 3800
L2CacheSize: 8192
L2CacheSpeed: None
Revision: 17408

Versions of relevant libraries:
[pip3] flake8==7.0.0
[pip3] mypy==1.11.2
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] numpydoc==1.7.0
[pip3] torch==2.8.0.dev20250615+cu128
[pip3] torchaudio==2.8.0.dev20250616+cu128
[pip3] torchvision==0.23.0.dev20250616+cu128
[pip3] triton-windows==3.3.1.post19
[conda] _anaconda_depends         2024.10             py312_mkl_0  
[conda] blas                      1.0                         mkl  
[conda] mkl                       2023.1.0         h6b88ed4_46358  
[conda] mkl-service               2.4.0           py312h2bbff1b_1  
[conda] mkl_fft                   1.3.10          py312h827c3e9_0  
[conda] mkl_random                1.2.7           py312h0158946_0  
[conda] numpy                     1.26.4          py312hfd52020_0  
[conda] numpy-base                1.26.4          py312h4dde369_0  
[conda] numpydoc                  1.7.0           py312haa95532_0  
[conda] torch                     2.8.0.dev20250615+cu128          pypi_0    pypi
[conda] torchaudio                2.8.0.dev20250616+cu128          pypi_0    pypi
[conda] torchvision               0.23.0.dev20250616+cu128          pypi_0    pypi
[conda] triton-windows            3.3.1.post19             pypi_0    pypi