In recent PyTorch versions (since when exactly? 2.0?), to get an efficient attention implementation, you can apparently just use torch.nn.functional.scaled_dot_product_attention, right? As I understand it, it would automatically use FlashAttention-2; the docs say it will "automatically select the most optimal implementation based on the inputs".
I'm not sure exactly what this means, though. What is the selection logic? In which cases would it select FlashAttention-2?
Also, as far as I know, FlashAttention-2 only works on more recent Nvidia GPUs, not on older ones. (Starting from which GPU generation? I guess a 1080 is not supported? A 2080?)
Alternatively, there is also memory-efficient attention. How does it compare in speed? Does it work on a 1080?
What is the most efficient implementation for a 1080, or what would be a reasonable choice there? Should I just use torch.nn.functional.scaled_dot_product_attention, or something more custom?
I have seen many other custom implementations, for example in segment_anything_fast. I'm not really sure whether that is outdated (i.e. made obsolete by recent PyTorch) or still makes sense. I think I have also seen a fast Triton implementation somewhere.
In some cases, I also need self-attention with relative positional encoding. As far as I understand, torch.nn.functional.scaled_dot_product_attention should support that, but I can imagine that not all backends do. I also need both cross-attention and self-attention. (It would be OK for me to use different implementations for each of those cases.)
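For what it's worth, one common way to express a relative positional bias with SDPA is to pass it as a floating-point attn_mask, which is added to the attention scores before the softmax. A minimal sketch (the bias table and indexing scheme here are just illustrative, not from any particular paper):

```python
import torch
import torch.nn.functional as F

B, H, T, D = 2, 4, 16, 32
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

# Hypothetical learnable bias: one scalar per head and relative distance,
# gathered into an (H, T, T) table of additive scores.
rel_bias = torch.zeros(H, 2 * T - 1, requires_grad=True)
rel_idx = torch.arange(T)[:, None] - torch.arange(T)[None, :] + (T - 1)  # (T, T)
bias = rel_bias[:, rel_idx]  # (H, T, T)

# A float attn_mask is added to the scaled q @ k^T scores before softmax,
# so it can carry the relative-position bias. Note that not every backend
# accepts an arbitrary mask (FlashAttention in particular is picky), so
# PyTorch may fall back to another implementation here.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias.unsqueeze(0))
print(out.shape)  # torch.Size([2, 4, 16, 32])
```

Cross-attention works the same way; you just pass k and v with a different sequence length than q.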
Does torch.nn.functional.scaled_dot_product_attention also work for training, i.e. is the gradient defined?
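(A quick sanity check I did suggests backward does run, at least with the default backend selection on CPU:)

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim), all requiring gradients.
q = torch.randn(1, 2, 8, 16, requires_grad=True)
k = torch.randn(1, 2, 8, 16, requires_grad=True)
v = torch.randn(1, 2, 8, 16, requires_grad=True)

out = F.scaled_dot_product_attention(q, k, v)
out.sum().backward()  # backward pass runs, so a gradient is defined

print(q.grad.shape)  # torch.Size([1, 2, 8, 16])
```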
I saw that torch.nn.MultiheadAttention for some reason does not use the native attention function when training mode is enabled. Why is that?