I understand that the FlashAttention / SDPA kernel needs to store the `softmax_lse` value (the log-sum-exp of the attention scores) for use in the backward pass during training.
However, in my case I also need to fetch this value in the forward pass. Is there a way to do this?
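For reference, here is a minimal pure-Python sketch of the value in question. It streams over the keys with an online softmax (the running max / running sum trick FlashAttention is built on) and returns the per-query log-sum-exp alongside the output. This is an illustration only, not any kernel's API. `attention_with_lse` is a hypothetical helper, and the assumption that its `lse` matches the kernel's `softmax_lse` (log of the row sum of exp(scale · q·k)) should be checked against your library's docs. In the `flash-attn` 2.x package, `flash_attn_func` also accepts `return_attn_probs=True`, which additionally returns `softmax_lse`, though its docstring describes that option as intended for testing.

```python
import math
from typing import List, Optional, Tuple

Vector = List[float]


def attention_with_lse(
    q: List[Vector],
    k: List[Vector],
    v: List[Vector],
    scale: Optional[float] = None,
) -> Tuple[List[Vector], Vector]:
    """Hypothetical reference: single-head attention returning (output, lse).

    lse[i] = log(sum_j exp(scale * dot(q[i], k[j]))), computed in one pass
    with an online softmax, so no full score row is ever materialized.
    """
    if scale is None:
        scale = 1.0 / math.sqrt(len(q[0]))  # standard 1/sqrt(d) scaling
    outputs: List[Vector] = []
    lse: Vector = []
    for qi in q:
        m = -math.inf             # running max of the scaled scores
        l = 0.0                   # running sum of exp(score - m)
        acc = [0.0] * len(v[0])   # running unnormalized weighted sum of v
        for kj, vj in zip(k, v):
            s = scale * sum(a * b for a, b in zip(qi, kj))
            m_new = max(m, s)
            correction = math.exp(m - m_new)  # rescales old terms; 0.0 on the first key
            p = math.exp(s - m_new)
            l = l * correction + p
            acc = [a * correction + p * x for a, x in zip(acc, vj)]
            m = m_new
        outputs.append([a / l for a in acc])  # normalized softmax(scores) @ v
        lse.append(m + math.log(l))           # log-sum-exp of this query's scores
    return outputs, lse
```

The LSE falls out of the forward pass for free (`m + log(l)`), which is why kernels keep it for the backward pass; exposing it is a question of whether the wrapper returns it.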