Hi all!
I am using flash attention in my model implementation, and I was wondering whether there is any way to get the attention weights that are computed inside torch.nn.functional.scaled_dot_product_attention. From the framework itself, there does not seem to be any way to do this.
I also wondered whether I could just re-implement this function in my model (the reference implementation is shown in torch.nn.functional.scaled_dot_product_attention — PyTorch 2.2 documentation), use sdp_kernel, and get the same results, which would let me retrieve the attention weights from the function.
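Something like this minimal sketch (based on the reference code in the docs, without dropout or causal masking; the name sdpa_with_weights is just mine):

```python
import torch

def sdpa_with_weights(query, key, value, attn_mask=None, scale=None):
    # Same math as F.scaled_dot_product_attention (no dropout / is_causal here),
    # but also returns the attention weights. Note that this materializes the
    # full attention matrix, so it gives up the memory savings of the fused kernel.
    scale_factor = query.size(-1) ** -0.5 if scale is None else scale
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    if attn_mask is not None:
        attn_weight = attn_weight + attn_mask
    attn_weight = torch.softmax(attn_weight, dim=-1)
    return attn_weight @ value, attn_weight
```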
What I am currently facing is that if I just do torch.einsum or torch.dot it causes an OOM. I think F.scaled_dot_product_attention does flash attention or something similar to optimize memory.
A large part of the memory (and time) savings of flash attention comes from not materializing the attention matrix.
Some things you can do (if you have the time):
Be sure to cast q and k to a lower precision (e.g. fp16) before the einsum, and maybe detach them.
You could also do the computation head by head.
Instead of doing the softmax, you could compute the logsumexp and subtract it in place (so the logsumexp becomes 0, i.e. you have done a log_softmax in place); see the sketch after this list.
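Rough sketch of what I mean (untested, just to illustrate the idea; it assumes q and k are shaped (batch, heads, seq, head_dim)):

```python
import torch

def per_head_log_attention(q, k, scale=None):
    # q, k: (batch, heads, seq, head_dim); detach + fp16 to shrink the intermediates
    scale = q.size(-1) ** -0.5 if scale is None else scale
    q = q.detach().half()
    k = k.detach().half()
    maps = []
    for h in range(q.size(1)):                      # one head at a time
        scores = torch.einsum('bqd,bkd->bqk', q[:, h], k[:, h]) * scale
        # subtract the logsumexp in place: each row becomes the log_softmax of the
        # scores, so exp() recovers the attention weights whenever you need them
        scores -= torch.logsumexp(scores.float(), dim=-1, keepdim=True).to(scores.dtype)
        maps.append(scores.cpu())                   # offload so only one head lives on GPU
    return torch.stack(maps, dim=1)                 # (batch, heads, seq, seq) log-weights
```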
One method I used: I am actually not interested in the whole NxN attention map, only in a kxN slice, where k is the number of tokens I care about. Extracting those queries first saves a lot of memory and time if k << N.
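Something along these lines (a rough sketch; token_idx and the shapes are just how I happened to set it up):

```python
import torch

def attention_for_selected_queries(q, k, token_idx, scale=None):
    # q, k: (batch, heads, seq, head_dim); token_idx: indices of the queries of interest
    scale = q.size(-1) ** -0.5 if scale is None else scale
    q_sel = q[:, :, token_idx, :]                       # (batch, heads, k, head_dim)
    scores = torch.einsum('bhqd,bhkd->bhqk', q_sel, k) * scale
    return torch.softmax(scores, dim=-1)                # (batch, heads, k, seq) instead of (seq, seq)
```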