Using F.scaled_dot_product_attention gives the error RuntimeError: No available kernel. Aborting execution


I'm trying to run the FlashAttention variant of F.scaled_dot_product_attention:

from collections import namedtuple
import torch
import torch.nn.functional as F

Config = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
self.cuda_config = Config(True, False, False)

with torch.backends.cuda.sdp_kernel(**self.cuda_config._asdict()):
    x = F.scaled_dot_product_attention(q, k, v)
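As a sanity check, you can inspect the process-wide backend flags that this context manager toggles (a small sketch; these getters exist in PyTorch 2.0+ and run on a CPU-only build too). Note that an enabled flag only allows dispatch to try that backend; it does not guarantee the kernel can actually run for a given dtype/shape/device:

```python
import torch

# Process-wide flags for the three scaled-dot-product-attention backends.
print("flash:        ", torch.backends.cuda.flash_sdp_enabled())
print("math:         ", torch.backends.cuda.math_sdp_enabled())
print("mem_efficient:", torch.backends.cuda.mem_efficient_sdp_enabled())
```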

I am on an A100-SXM and tried running this with:

  • CUDA version 12.0 and PyTorch 2.1.0.dev20230526+cu121
  • CUDA 11.7 and PyTorch 2.0.1

I can't find any references to this error, and I'm not sure what I'm doing wrong.

It works just fine with
Config(False, True, True), which uses the math and memory-efficient attention backends, but I would prefer to use Flash Attention.
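For context, the math backend computes the textbook formula softmax(q k^T / sqrt(d)) v without any fused kernel. A minimal pure-Python sketch of that computation (the helper name `sdpa_reference` is mine, purely for illustration; no GPU or torch required):

```python
import math

def sdpa_reference(q, k, v):
    """Textbook scaled dot-product attention: softmax(q @ k^T / sqrt(d)) @ v.

    q, k, v are lists of row vectors (lists of floats).
    """
    d = len(q[0])
    out = []
    for qi in q:
        # Attention scores for this query against every key.
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        m = max(scores)                     # subtract max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        weights = [e / z for e in exps]     # softmax over the keys
        # Weighted sum of the value rows.
        out.append([sum(w * vj[c] for w, vj in zip(weights, v))
                    for c in range(len(v[0]))])
    return out
```

With a single key, the softmax weight is 1 and the output is just the value row; with two keys scoring equally, the output is the average of the two value rows.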

Would appreciate any help in this context!

Could you check if you are seeing the same issue in the recent nightly binary, please?


Thank you, it's working now. It frequently warns about 'reduction over non-contiguous data', but it seems to be working.

If you don't mind, I have two follow-up questions:

  1. Does it hold that the Flash Attention implementation available in PyTorch is only usable with float16/bfloat16, like the original repo's implementation? Or would it work with float32 as well?

  2. Is this fully compatible with torch.compile?

For those interested, this error no longer occurs after updating to PyTorch nightly version 2.1.0.dev20230527+cu121.

How can you pip install the nightly release 2.1.0.dev20230527+cu121?

Just install the latest nightly:

pip3 install --pre torch torchvision torchaudio --index-url