Hi,
Using the latest PyTorch 2.2.0 (installed from pip), running on an AMD MI250X on an HPC system, I find that FlashAttention does not work:
/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:5476: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:320.)
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:5476: UserWarning: Memory efficient kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:416.)
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:5476: UserWarning: Flash attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:418.)
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
Epoch 1 train loop on rank=0: 0%| | 0/45943 [00:01<?, ?it/s]
...
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
RuntimeError: No available kernel. Aborting execution.
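If it helps, here is a reduced, standalone sketch of the call site (the shapes, dtype, and `is_causal` flag are placeholders, not the values from my actual model). Restricting SDPA to the flash backend only should isolate the flash path instead of silently falling back to the math kernel:

```python
import torch
import torch.nn.functional as F

# Placeholder tensors: (batch, num_heads, seq_len, head_dim), fp16 on the GPU.
# seq_len is a power of two and head_dim is 64 to stay inside the limits
# listed in the ROCm flash-attention PR quoted further down.
q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Force the flash backend only (math and mem-efficient disabled), so any
# failure surfaces directly rather than being masked by a fallback.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)
```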
When I check the GCN arch and whether flash sdp is enabled, I get:
>>> print("GCN arch:", torch.cuda.get_device_properties('cuda').gcnArchName)
GCN arch: gfx90a:sramecc+:xnack-
>>> print("FlashAttention available:", torch.backends.cuda.flash_sdp_enabled())
FlashAttention available: True
>>> torch.__version__
'2.2.0+rocm5.7'
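In case it's useful, here is a slightly broader diagnostic using the standard `torch.backends.cuda` query functions (nothing here is ROCm-specific apart from `gcnArchName` and `torch.version.hip`):

```python
import torch

# Build and runtime info
print("torch:", torch.__version__, "| HIP:", torch.version.hip)
print("device:", torch.cuda.get_device_properties(0).gcnArchName)

# Which SDPA backends are currently enabled
print("flash sdp enabled:        ", torch.backends.cuda.flash_sdp_enabled())
print("mem-efficient sdp enabled:", torch.backends.cuda.mem_efficient_sdp_enabled())
print("math sdp enabled:         ", torch.backends.cuda.math_sdp_enabled())
```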
I see that FlashAttention support for ROCm was being worked on:
pytorch:main ← ROCm:xinyazhang/up-fa-mathaot (opened 21 Nov 23 UTC)
This pull request adds initial Flash Attention support for the AMD/ROCm platform. It added a specialized Triton repository/branch as a compile-time dependency for the Flash Attention math library on AMD/ROCm. This Triton submodule is not used at runtime and will not be shipped in the final PyTorch package. We plan to release this specialized Triton as a separate project.
Known limitations:
- [ ] Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
- [ ] Only supports power of two sequence lengths.
- [ ] No support for varlen APIs.
- [ ] Only supports head dimensions 16, 32, 64, 128.
- [ ] Performance is still being optimized.
Fixes #112997
cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang
but it was reverted, and work was done to re-add it, although that PR was closed before merging (?):
pytorch:main ← ROCm:xinyazhang/up-fa-mathaot (opened 16 Dec 23 UTC)
Note about the Updates:
This PR:
1. Skips more flash attention related UTs on MI200
2. Fixes additional ATen compiling errors after hipification
3. Fixes the author "root" of a specific commit
4. Includes the patch from Nikita in favor of block-level static initialization.
CAVEAT: This revised PR has a commit that modifies the CI to force it to run on MI200 nodes. That specific commit must be reverted before merging.
Original PR (https://github.com/pytorch/pytorch/pull/114309) Note:
This pull request adds initial Flash Attention support for the AMD/ROCm platform. It added a specialized Triton repository/branch as a compile-time dependency for the Flash Attention math library on AMD/ROCm. This Triton submodule is not used at runtime and will not be shipped in the final PyTorch package. We plan to release this specialized Triton as a separate project.
Known limitations:
- Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
- Only supports power of two sequence lengths.
- No support for varlen APIs.
- Only supports head dimensions 16, 32, 64, 128.
- Performance is still being optimized.
Fixes #112997
cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler
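For what it's worth, here is a quick sketch of how I'd check input tensors against the limitations quoted above (the helper name and the (batch, num_heads, seq_len, head_dim) layout are my own assumptions, not anything from the PR):

```python
import torch

# Head dims and the power-of-two rule are taken from the PR text quoted above.
_ALLOWED_HEAD_DIMS = {16, 32, 64, 128}

def meets_rocm_flash_limits(q: torch.Tensor) -> bool:
    """Assumes q is laid out as (batch, num_heads, seq_len, head_dim)."""
    seq_len, head_dim = q.shape[-2], q.shape[-1]
    # MI200 series reports gfx90a; MI250X is part of that family.
    is_mi200 = torch.cuda.get_device_properties(q.device).gcnArchName.startswith("gfx90a")
    pow2_seq_len = seq_len > 0 and (seq_len & (seq_len - 1)) == 0
    return is_mi200 and pow2_seq_len and head_dim in _ALLOWED_HEAD_DIMS

q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
print(meets_rocm_flash_limits(q))  # expected True on an MI250X (gfx90a)
```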
So I'm not sure whether this is supposed to work with PyTorch 2.2.0 yet?
Can any AMD folks (@xinyazhang @jithunnair-amd) confirm?
Thanks!