I’m trying to run the FlashAttention variant of F.scaled_dot_product_attention.
Config = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
self.cuda_config = Config(True, False, False)
with torch.backends.cuda.sdp_kernel(**self.cuda_config._asdict()):
    x = F.scaled_dot_product_attention(q, k, v)
I am on an A100-SXM and tried running this with:
CUDA 12.0 and PyTorch 2.1.0.dev20230526+cu121
CUDA 11.7 and PyTorch 2.0.1
I can’t find any references to this error, and I’m not sure what I’m doing wrong.
It works just fine with Config(False, True, True), which uses the math and memory-efficient attention backends, but I would prefer to use FlashAttention.
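For reference, the combination that does work for me is just the snippet above with the flags changed:
self.cuda_config = Config(False, True, True)
with torch.backends.cuda.sdp_kernel(**self.cuda_config._asdict()):
    x = F.scaled_dot_product_attention(q, k, v)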
Thank you, it’s working now. It frequently reports ‘reduction over non-contiguous data’, but it seems to be working.
If you don’t mind, I have two follow-up questions:
Does the FlashAttention implementation available in PyTorch only work with float16/bfloat16, like the original repo’s implementation, or would it work with float32 as well? (I put a small sketch of what I mean after the second question.)
Is this fully compatible with torch.compile?
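For the first question, here is a minimal sketch of what I would try, assuming the flash kernel really does need fp16/bf16 inputs (that assumption is exactly what I’m asking about; q, k, v are the tensors from my snippet above):
# Cast the inputs down for the fused kernel, then cast the output back.
q_h, k_h, v_h = (t.to(torch.bfloat16) for t in (q, k, v))
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    x = F.scaled_dot_product_attention(q_h, k_h, v_h).to(q.dtype)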
For those interested, this error no longer occurs after updating to PyTorch nightly version 2.1.0.dev20230527+cu121.
I get the same error with PyTorch 2.3.0+cu121 built from source. What could be the reason?
I’m running on a cluster of Tesla V100-SXM2-32GB GPUs with CUDA 12.4.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    attn_out = F.scaled_dot_product_attention(xq, keys, values)
Based on this code, Volta GPUs are not supported.
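As a quick check, you can inspect the compute capability before trying to force the flash backend. This is just a sketch; the exact minimum architecture is an assumption here, but Volta (sm70) is below it either way:
import torch

# V100 is sm70 (Volta); A100 is sm80 (Ampere).
major, minor = torch.cuda.get_device_capability()
if major < 8:
    print(f"sm{major}{minor}: the flash SDP kernel is likely unavailable on this GPU")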
I would generally not recommend forcing a specific algorithm; instead, let PyTorch select the fastest one for the device you are using.
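For example, calling it without any sdp_kernel context manager lets PyTorch dispatch to whichever fused backend is supported on the current device (the shapes below are placeholders):
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)

# No backend is forced: PyTorch picks the fastest available implementation.
out = F.scaled_dot_product_attention(q, k, v)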
Yes, it works if I let PyTorch choose the best algorithm. But flash attention alone does not seem to work, as it does not support a separate attention mask.
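In case it helps, a small sketch of what I mean (as far as I understand, passing an explicit attn_mask makes the flash backend ineligible, while is_causal=True keeps it eligible; the tensors are placeholders):
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)

# An explicit boolean mask (True = attend) typically rules out the flash kernel.
mask = torch.ones(1024, 1024, device="cuda", dtype=torch.bool).tril()
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Expressing causality via is_causal avoids materializing a mask and keeps flash eligible.
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)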