FlashAttention on AMD MI250X not working with PyTorch 2.2.0

Hi,

Using the latest PyTorch 2.2.0, installed from pip and running on an AMD MI250X on an HPC system, I find that FlashAttention does not work:

/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:5476: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:320.)
  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:5476: UserWarning: Memory efficient kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:416.)
  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:5476: UserWarning: Flash attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:418.)
  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
Epoch 1 train loop on rank=0:   0%|          | 0/45943 [00:01<?, ?it/s]
...
    attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
RuntimeError: No available kernel. Aborting execution.
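
For context, the same "No available kernel" error is easy to reproduce in isolation by restricting SDPA to the flash backend; a minimal sketch (tensor shapes are made up for illustration):

import torch
import torch.nn.functional as F

# Made-up shapes: (batch, heads, seq_len, head_dim), half precision on the GPU
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Allow only the flash backend; on a wheel without the flash kernel compiled in,
# this raises "RuntimeError: No available kernel. Aborting execution."
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)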

When I check the GCN arch and whether flash SDP is enabled, I get:

>>> print("GCN arch:", torch.cuda.get_device_properties('cuda').gcnArchName)
GCN arch: gfx90a:sramecc+:xnack-
>>> print("FlashAttention available:", torch.backends.cuda.flash_sdp_enabled())
FlashAttention available: True
>>> torch.__version__
'2.2.0+rocm5.7'
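
As far as I understand, flash_sdp_enabled() only reports whether the flash backend is allowed by the current settings, not whether a flash kernel was actually compiled into the wheel, which is why it returns True even though no kernel is available. The other backend flags can be checked the same way:

import torch

# These report whether each backend is *allowed*, not whether a kernel
# for it actually exists in this build / for this GPU.
print("flash allowed:        ", torch.backends.cuda.flash_sdp_enabled())
print("mem-efficient allowed:", torch.backends.cuda.mem_efficient_sdp_enabled())
print("math allowed:         ", torch.backends.cuda.math_sdp_enabled())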

I see that FlashAttention support for ROCm was being worked on:

but it was reverted, and work was done to re-add it, although that PR appears to have been closed before merging (?):

so I’m not sure whether this is supposed to work with PyTorch 2.2.0 yet.
Could any AMD folks (@xinyazhang @jithunnair-amd) confirm?

Thanks!

Answering myself in case anyone else stumbles upon this.

As far as I can tell, the PyTorch wheels supplied at https://download.pytorch.org/whl/rocm5.7 do not have Flash Attention for ROCm/AMD built in, so it’s not callable via torch.nn.functional.scaled_dot_product_attention directly.
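
As a stopgap, scaled_dot_product_attention still runs on these wheels if you let it fall back to the (much slower) math implementation, which is always available; a minimal sketch with made-up shapes:

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Explicitly fall back to the math backend so the call cannot hit
# "No available kernel" on a build without flash / mem-efficient kernels.
with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_mem_efficient=False, enable_math=True
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)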

I was able to install ROCmSoftwarePlatform/flash-attention.git on top of the Docker image rocm/pytorch:rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2 and then replace torch.nn.MultiheadAttention with flash_attn.modules.mha.MHA. However, this means extra steps still have to be taken to run Flash Attention on AMD, and there is no full API compatibility as of 2.2.0.
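
For anyone trying the same route, the module swap looked roughly like the sketch below. This is from memory, so treat the constructor arguments (causal, use_flash_attn, etc.) as assumptions and check them against the flash-attention version you actually build:

import torch
from flash_attn.modules.mha import MHA

# Rough stand-in for torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True);
# argument names may differ between flash-attention versions.
mha = MHA(
    embed_dim=512,
    num_heads=8,
    dropout=0.1,
    causal=True,          # replaces the causal attn_mask / is_causal flag
    use_flash_attn=True,
).to("cuda", dtype=torch.float16)  # flash kernels need fp16/bf16 on the GPU

x = torch.randn(2, 128, 512, device="cuda", dtype=torch.float16)  # (batch, seq, embed)
out = mha(x)  # self-attention; returns just the output tensor, no attention weights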

Hopefully subsequent releases will have Flash Attention for AMD built in… :slight_smile:
