Retrieving Attention Weights from scaled_dot_product_attention

Hi all!
I am using flash attention in my model implementation, and I was wondering whether there is any way of getting the attention weights that are computed inside torch.nn.functional.scaled_dot_product_attention. The framework itself does not seem to offer any way to do this.
I also wondered whether I could simply implement this function in my model (a reference implementation is given in torch.nn.functional.scaled_dot_product_attention — PyTorch 2.2 documentation), use the sdp_kernel, and get the same results (which would enable me to retrieve the attention weights from the function).
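Something along these lines is what I have in mind: a naive sketch of the default case (no attn_mask, no dropout, no is_causal), which of course materializes the full attention matrix:

import math
import torch

def sdpa_with_weights(q, k, v):
    # Naive equivalent of the default scaled_dot_product_attention call
    # that also returns the attention matrix (no mask / dropout handled).
    scale = 1.0 / math.sqrt(q.size(-1))
    attn = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    return attn @ v, attn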

Thank you!

@anto1481 Did you manage to do it? How did you end up retrieving the attention weights?
Thanks.

@Miquel_Espinosa @anto1481 Did you manage to do it? I also face this problem but don’t know how to solve it…

I think I ended up doing something like this…

# q, k: [batch, heads, seq_len, head_dim]
attention_scores = torch.einsum('bhqd,bhkd->bhqk', q, k)
self.attn_weights = F.softmax(attention_scores / math.sqrt(q.size(-1)), dim=-1)
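To sanity-check that this matches the fused path (assuming q, k, v are [batch, heads, seq_len, head_dim] and no mask or dropout is used; the tolerances here are just a guess):

out_manual = torch.einsum('bhqk,bhkd->bhqd', self.attn_weights, v)
out_fused = F.scaled_dot_product_attention(q, k, v)
torch.testing.assert_close(out_manual, out_fused, rtol=1e-3, atol=1e-3)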

What I am currently facing is that if I just do the torch.einsum (or a plain matmul) it causes an OOM. I think F.scaled_dot_product_attention uses flash attention or something similar to optimize the memory.

A large part of the memory (and time) savings of flash attention comes from not materializing the attention matrix.
Some things you can do (if you have the time):

  • Be sure to cast q and k to a lower precision before the einsum, and maybe detach them.
  • You could also do the computation head by head.
  • Instead of doing the softmax directly, you could
    • compute the logsumexp and subtract it in place (so the logsumexp becomes 0, i.e. you have done log_softmax in place),
    • then compute the exp in place (see the sketch below).
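Putting those together, a rough sketch (a hypothetical helper of mine, assuming q and k are [batch, heads, seq_len, head_dim]; it yields one head at a time to keep peak memory down):

import torch

def iter_attention_maps(q, k, dtype=torch.float16):
    # Yield the attention map one head at a time to limit peak memory.
    scale = q.size(-1) ** -0.5
    q = q.detach().to(dtype)
    k = k.detach().to(dtype)
    for h in range(q.size(1)):
        scores = torch.einsum('bqd,bkd->bqk', q[:, h], k[:, h]) * scale
        # in-place softmax: subtract the logsumexp (log_softmax in place) ...
        scores -= torch.logsumexp(scores, dim=-1, keepdim=True)
        # ... then exponentiate in place
        scores.exp_()
        yield scores  # [batch, seq_len, seq_len] attention weights for head h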

What I currently want is query @ key before the softmax, so what is the best way to get it?

One method I used: I am actually not interested in the whole NxN attention map, only in a kxN slice, where k is the number of tokens of interest. Extracting those queries first saves a lot of memory and time if k << N.
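A small sketch of that idea (the names here are mine; token_idx would hold the indices of the k query tokens of interest):

import math
import torch

def partial_attention_map(q, k, token_idx):
    # q, k: [batch, heads, seq_len, head_dim]; token_idx: LongTensor with the k query positions
    q_sel = q[:, :, token_idx, :]                       # [batch, heads, k, head_dim]
    scores = torch.einsum('bhqd,bhkd->bhqk', q_sel, k)  # [batch, heads, k, seq_len]
    return torch.softmax(scores / math.sqrt(q.size(-1)), dim=-1)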