I have a question regarding the Flash Attention implementation in PyTorch 2.0.
Does PyTorch ‘only’ implement the fused kernel, i.e. all self-attention operations with a single read/write to HBM?
Or, does it also implement tiling and checkpointing/recomputation during backprop as in the Flash Attention paper? I’m not a hardware guy and this question might be simple, but it was not obvious to me when I read through the paper and the docs.
Another question that is more related to Flash Attention itself: Why is memory usage linear in sequence length during the forward pass? I was expecting sublinear due to tiling, but probably I’m missing something.
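To make my expectation concrete, here is a minimal pure-Python sketch of how I understand tiling with a running softmax. It is a toy I wrote to check my reasoning, not PyTorch's kernel and not the paper's exact algorithm: the function names and the `block` parameter are my own, and for simplicity it tiles only over keys/values and handles one query row at a time.

```python
import math
import random


def naive_attention(q, k, v):
    """Reference version: computes a full row of N scores per query."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wj * vj[c] for wj, vj in zip(w, v)) / z
                    for c in range(len(v[0]))])
    return out


def tiled_attention(q, k, v, block=4):
    """Toy tiled version: walks K/V in blocks and keeps a running softmax."""
    n, dv = len(q), len(v[0])
    scale = 1.0 / math.sqrt(len(q[0]))
    out = [[0.0] * dv for _ in range(n)]  # N x d output buffer: grows with N
    row_max = [-math.inf] * n             # one running max per query row
    row_sum = [0.0] * n                   # one running normalizer per query row
    for start in range(0, len(k), block):
        k_blk = k[start:start + block]
        v_blk = v[start:start + block]
        for i, qi in enumerate(q):
            # Scratch space here holds at most `block` scores, independent of N.
            scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k_blk]
            new_max = max(row_max[i], max(scores))
            # Rescale what was accumulated so far to the new running max.
            correction = math.exp(row_max[i] - new_max)
            w = [math.exp(s - new_max) for s in scores]
            row_sum[i] = row_sum[i] * correction + sum(w)
            for c in range(dv):
                out[i][c] = out[i][c] * correction + sum(
                    wj * vj[c] for wj, vj in zip(w, v_blk))
            row_max[i] = new_max
    # Final normalization by the accumulated softmax denominator.
    return [[x / row_sum[i] for x in out[i]] for i in range(n)]


if __name__ == "__main__":
    random.seed(0)
    n, d = 10, 3
    q, k, v = ([[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
               for _ in range(3))
    ref, got = naive_attention(q, k, v), tiled_attention(q, k, v, block=4)
    print(max(abs(a - b) for r, g in zip(ref, got) for a, b in zip(r, g)))
```

In this toy, the per-block scratch space does not depend on the sequence length, but the output buffer and the per-row statistics (`row_max`, `row_sum`) each have one entry per query. Is that the linear part, or is there more to it?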
Thanks in advance