FlashAttention PyTorch integration

Hi, I’m trying to experiment with and make tweaks and potential upgrades to FlashAttention, and I’m wondering where the best place to start is. Does the PyTorch integration copy-paste/pull from the original FlashAttention repo, or are there implementation changes made as part of the integration?

Thanks!

This doesn’t answer your question, but relatedly, for anyone interested in making changes to attention while still keeping fast kernels, there’s pytorch/torch/nn/attention/_flex_attention.py at 93a33bf3ac0b4c9560b49780eabcad2f76dcf43e · pytorch/pytorch · GitHub