Code:
I registered hooks on the inputs and the output of SDPA and printed the maximum absolute value of each gradient.
When the standard deviation of the key is large and the sequence length is slightly increased, abnormal gradients occur during backward on the RTX 5080, while the gradient scale remains normal on the CPU.
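As a rough aid for reading the `k_scale` parameter in the script below, the following back-of-the-envelope estimate shows how the key's standard deviation carries over to the attention logits. It assumes independent entries, $q_i \sim \mathcal{N}(0, 1)$ and $k_i \sim \mathcal{N}(0, \sigma_k^2)$, with head dimension $d = 128$ and scale $s = 1/\sqrt{d}$ (the value `0.08838834764831845` used in the script). The symbols $d$, $s$ and $\sigma_k$ are introduced here for illustration only, and the estimate does not by itself identify the cause of the abnormal gradients.

```latex
\operatorname{Var}\!\left(s \sum_{i=1}^{d} q_i k_i\right)
  = s^2 \sum_{i=1}^{d} \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2]
  = \frac{1}{d} \cdot d \cdot \sigma_k^2
  = \sigma_k^2
```

Under these assumptions the pre-softmax logits have a standard deviation of about `k_scale`, so `k_scale=20.0` gives logits with a spread of roughly 20 and a softmax that is close to one-hot.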
import torch
from torch.nn.functional import scaled_dot_product_attention as sdpa
from torch.nn.functional import mse_loss

print("torch:", torch.__version__)
print("cudnn:", torch.backends.cudnn.version())


def verify_gard(device, seq_len, k_scale):
    print('')
    print(f"device={device} seq_len={seq_len} k_scale={k_scale}")
    q = torch.randn((1, 16, seq_len, 128), dtype=torch.float32, device=device, requires_grad=True)
    # k is a non-leaf tensor (randn * k_scale), so its hook reports the gradient w.r.t. the scaled key.
    k = torch.randn((1, 16, seq_len, 128), dtype=torch.float32, device=device, requires_grad=True) * k_scale
    v = torch.randn((1, 16, seq_len, 128), dtype=torch.float32, device=device, requires_grad=True)
    target = torch.randn(1, 16, seq_len, 128, requires_grad=False).to(device)
    dropout = 0
    scaling = 0.08838834764831845  # 1 / sqrt(128), i.e. 1 / sqrt(head_dim)
    is_causal = True

    # Print the maximum absolute gradient for each SDPA input during backward.
    q.register_hook(lambda grad: print(f"{'q':<10}", torch.max(torch.abs(grad)).item()))
    k.register_hook(lambda grad: print(f"{'k':<10}", torch.max(torch.abs(grad)).item()))
    v.register_hook(lambda grad: print(f"{'v':<10}", torch.max(torch.abs(grad)).item()))

    attn_output = sdpa(
        query=q,
        key=k,
        value=v,
        attn_mask=None,
        dropout_p=dropout,
        scale=scaling,
        is_causal=is_causal,
    )
    # Same check for the gradient flowing into the SDPA output.
    attn_output.register_hook(lambda grad: print(f"{'out':<10}", torch.max(torch.abs(grad)).item()))
    loss = mse_loss(attn_output, target)
    loss.backward()


verify_gard('cpu', 151, 1.0)
verify_gard('cuda', 151, 1.0)
verify_gard('cpu', 151, 20.0)
verify_gard('cuda', 151, 20.0)
verify_gard('cpu', 20, 20.0)
verify_gard('cuda', 20, 20.0)
Result:
torch: 2.8.0.dev20250615+cu128
cudnn: 90701
device=cpu seq_len=151 k_scale=1.0
out 3.560339973773807e-05
q 5.226455323281698e-05
v 6.997660238994285e-05
k 6.137766467873007e-05
device=cuda seq_len=151 k_scale=1.0
out 3.566638406482525e-05
q 5.120061177876778e-05
v 6.786640733480453e-05
k 6.284983101068065e-05
device=cpu seq_len=151 k_scale=20.0
out 4.353544136392884e-05
q 0.001123013673350215
v 0.0001919580390676856
k 5.425199560704641e-05
device=cuda seq_len=151 k_scale=20.0
out 4.079606878804043e-05
q 2272.349365234375
v 60.08359146118164
k 121.35572052001953 <- abnormal
device=cpu seq_len=20 k_scale=20.0
out 0.0003049526712857187
q 0.006180753465741873
v 0.0009748771553859115
k 0.000328497146256268
device=cuda seq_len=20 k_scale=20.0
out 0.00029924517730250955
q 0.006182350218296051
v 0.0008953474462032318
k 0.0003514259879011661
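To put a number on how far the CUDA values are off, the ratios between the reported CUDA and CPU maxima for `seq_len=151` can be computed directly. This is a small sketch using only the figures printed above. The function name `cuda_to_cpu_ratios` is hypothetical, and because the CPU and CUDA runs draw different random inputs, the ratios should be read as orders of magnitude only.

```python
# Max |grad| values copied from the reported output for seq_len=151.
reported = {
    1.0: {
        "cpu": {"out": 3.560339973773807e-05, "q": 5.226455323281698e-05,
                "v": 6.997660238994285e-05, "k": 6.137766467873007e-05},
        "cuda": {"out": 3.566638406482525e-05, "q": 5.120061177876778e-05,
                 "v": 6.786640733480453e-05, "k": 6.284983101068065e-05},
    },
    20.0: {
        "cpu": {"out": 4.353544136392884e-05, "q": 0.001123013673350215,
                "v": 0.0001919580390676856, "k": 5.425199560704641e-05},
        "cuda": {"out": 4.079606878804043e-05, "q": 2272.349365234375,
                 "v": 60.08359146118164, "k": 121.35572052001953},
    },
}


def cuda_to_cpu_ratios(run):
    """Return the CUDA / CPU ratio of max |grad| for each tensor in one run."""
    return {name: run["cuda"][name] / run["cpu"][name] for name in run["cpu"]}


for k_scale, run in reported.items():
    ratios = cuda_to_cpu_ratios(run)
    summary = ", ".join(f"{name}={ratio:.3g}" for name, ratio in ratios.items())
    print(f"k_scale={k_scale}: {summary}")
```

With `k_scale=1.0` all ratios stay close to 1, while with `k_scale=20.0` the `q`, `v` and `k` ratios are several orders of magnitude larger, even though the gradient reaching the SDPA output (`out`) is still comparable on both devices.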
Versions
PyTorch version: 2.8.0.dev20250615+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A
OS: Microsoft Windows 11 Pro (10.0.26100 64-bit)
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A
Python version: 3.12.7 | packaged by Anaconda, Inc. | (main, Oct 4 2024, 13:17:27) [MSC v.1929 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-11-10.0.26100-SP0
Is CUDA available: True
CUDA runtime version: 12.9.86
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 5080
Nvidia driver version: 576.52
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Name: AMD Ryzen 7 9700X 8-Core Processor
Manufacturer: AuthenticAMD
Family: 107
Architecture: 9
ProcessorType: 3
DeviceID: CPU0
CurrentClockSpeed: 3800
MaxClockSpeed: 3800
L2CacheSize: 8192
L2CacheSpeed: None
Revision: 17408
Versions of relevant libraries:
[pip3] flake8==7.0.0
[pip3] mypy==1.11.2
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] numpydoc==1.7.0
[pip3] torch==2.8.0.dev20250615+cu128
[pip3] torchaudio==2.8.0.dev20250616+cu128
[pip3] torchvision==0.23.0.dev20250616+cu128
[pip3] triton-windows==3.3.1.post19
[conda] _anaconda_depends 2024.10 py312_mkl_0
[conda] blas 1.0 mkl
[conda] mkl 2023.1.0 h6b88ed4_46358
[conda] mkl-service 2.4.0 py312h2bbff1b_1
[conda] mkl_fft 1.3.10 py312h827c3e9_0
[conda] mkl_random 1.2.7 py312h0158946_0
[conda] numpy 1.26.4 py312hfd52020_0
[conda] numpy-base 1.26.4 py312h4dde369_0
[conda] numpydoc 1.7.0 py312haa95532_0
[conda] torch 2.8.0.dev20250615+cu128 pypi_0 pypi
[conda] torchaudio 2.8.0.dev20250616+cu128 pypi_0 pypi
[conda] torchvision 0.23.0.dev20250616+cu128 pypi_0 pypi
[conda] triton-windows 3.3.1.post19 pypi_0 pypi