Runtime Error when using Flash Attention

I ran into the following runtime error when trying to use Flash Attention with scaled_dot_product_attention. I tried installing PyTorch via “stable (conda)”, “nightly (conda)”, and “compile from source”, and all of them gave me the same error.

Thanks in advance for any advice!

Minimal code to reproduce:

import torch
Q = torch.zeros(3, 10, 128).cuda()  # (batch, seq_len, head_dim)
K = torch.zeros(3, 10, 128).cuda()
V = torch.zeros(3, 10, 128).cuda()
with torch.backends.cuda.sdp_kernel(  # allow only the Flash Attention backend
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False,
):
    print(torch.nn.functional.scaled_dot_product_attention(Q, K, V).shape)

Error log:

test_sdp.py:10: UserWarning: Memory efficient kernel not used because: (Triggered internally at /home/conda/feedstock_root/build_artifacts/pytorch-recipe_1680527322149/work/aten/src/ATen/native/transformers/cuda/sdp_utils.h:527.)
  print(torch.nn.functional.scaled_dot_product_attention(Q, K, V).shape)
Traceback (most recent call last):
  File "/data/home/cywu/test_sdp.py", line 10, in <module>
    print(torch.nn.functional.scaled_dot_product_attention(Q, K, V).shape)
RuntimeError: Torch was not compiled with flash attention.

Versions
Collecting environment information…
PyTorch version: 2.0.0.post200
Is debug build: False
CUDA used to build PyTorch: 11.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.25.2
Libc version: glibc-2.31

Python version: 3.9.16 | packaged by conda-forge | (main, Feb 1 2023, 21:39:03) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-1030-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB

Nvidia driver version: 525.85.12
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 96
On-line CPU(s) list: 0-95
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
Stepping: 7
CPU MHz: 1263.993
BogoMIPS: 5999.99
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 1.5 MiB
L1i cache: 1.5 MiB
L2 cache: 48 MiB
L3 cache: 71.5 MiB
NUMA node0 CPU(s): 0-23,48-71
NUMA node1 CPU(s): 24-47,72-95
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke

Versions of relevant libraries:
[pip3] numpy==1.24.2
[pip3] pytorch3d==0.7.3
[pip3] torch==2.0.0.post200
[pip3] torchaudio==2.0.0
[pip3] torchvision==0.14.1a0+59d9189
[conda] cudatoolkit 11.3.1 h2bc3f7f_2
[conda] libmagma 2.7.1 hc72dce7_1 conda-forge
[conda] libmagma_sparse 2.7.1 hc72dce7_2 conda-forge
[conda] magma 2.7.1 ha770c72_2 conda-forge
[conda] mkl 2022.2.1 h84fe81f_16997 conda-forge
[conda] numpy 1.24.2 py39h7360e5f_0 conda-forge
[conda] pytorch 2.0.0 cuda112py39ha9981d0_200 conda-forge
[conda] pytorch-cuda 11.8 h7e8668a_3 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] pytorch3d 0.7.3 dev_0
[conda] torch 2.1.0a0+git9a2a6fc dev_0
[conda] torchaudio 2.0.0 py39_cu118 pytorch
[conda] torchvision 0.14.1 cuda112py39hb350dc8_1 conda-forge

It seems you are using unofficial conda binaries from conda-forge, created by mark.harfouche, which do not appear to ship with FlashAttention.
You can see this from the custom version tag:

PyTorch version: 2.0.0.post200
...
[conda] pytorch 2.0.0 cuda112py39ha9981d0_200 conda-forge

as well as from the old CUDA runtime, which we no longer use for our current builds:

CUDA used to build PyTorch: 11.2
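
For reference, after switching to the official binaries (e.g. via pip3 install torch --index-url https://download.pytorch.org/whl/cu118, or the matching conda command from the install selector), something like the sketch below could be used to verify the setup. This is only a minimal sketch assuming the official 2.0.0+cu118 build: the fused Flash Attention kernel also expects float16/bfloat16 inputs and 4-dimensional (batch, num_heads, seq_len, head_dim) tensors, so the toy tensors from the repro are adjusted accordingly.

import torch

# Official builds report a plain version string (no custom "post200" suffix)
# and the CUDA toolkit they were built against.
print(torch.__version__)   # e.g. 2.0.0+cu118
print(torch.version.cuda)  # e.g. 11.8

# Flash Attention needs half-precision inputs and 4D
# (batch, num_heads, seq_len, head_dim) tensors; head_dim <= 128 on an A100.
Q = torch.zeros(3, 8, 10, 128, device="cuda", dtype=torch.float16)
K = torch.zeros(3, 8, 10, 128, device="cuda", dtype=torch.float16)
V = torch.zeros(3, 8, 10, 128, device="cuda", dtype=torch.float16)

with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False,
):
    out = torch.nn.functional.scaled_dot_product_attention(Q, K, V)
    print(out.shape)  # torch.Size([3, 8, 10, 128]) if the flash kernel ran

If this still raises “No available kernel”, the warnings printed right before the error should tell you which constraint (dtype, shape, or GPU architecture) is not met.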