How to get the softmax_lse value from the SDPA kernel?

I understand that the flash attention/SDPA kernel needs to store the softmax_lse value (the log-sum-exp of the attention scores) for use in the backward pass during training.
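For reference, here is the quantity in question. Assume the default scaling $1/\sqrt{d}$ and no mask, with scores $S = QK^\top/\sqrt{d}$. The stored value for query row $i$ is the log of the softmax normalizer, and the kernel can rebuild the softmax from it:

```latex
\mathrm{lse}_i = \log \sum_{j} \exp\left(S_{ij}\right), \qquad
P_{ij} = \operatorname{softmax}(S)_{ij} = \exp\left(S_{ij} - \mathrm{lse}_i\right), \qquad
O = P V
```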

However, in my case I also need this value in the forward pass. Is there a way to get it?
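`F.scaled_dot_product_attention` itself returns only the attention output. One workaround is to recompute the log-sum-exp eagerly next to an equivalent attention computation. This is a minimal sketch that assumes a non-fused fallback is acceptable. `attention_with_lse` is a hypothetical helper, not a PyTorch API. It materializes the full L×S score matrix, so it gives up the memory savings of the fused kernel. It does not read the kernel's internal buffer.

```python
import math

import torch
import torch.nn.functional as F


def attention_with_lse(q, k, v, is_causal=False, scale=None):
    """Hypothetical reference attention that also returns the row-wise log-sum-exp.

    q: (..., L, E), k: (..., S, E), v: (..., S, Ev).
    Returns (out, lse) with out: (..., L, Ev) and lse: (..., L).
    """
    if scale is None:
        # Same default scaling as scaled_dot_product_attention.
        scale = 1.0 / math.sqrt(q.size(-1))
    # Raw attention scores, shape (..., L, S), computed in float32 for stability.
    scores = (q.float() @ k.float().transpose(-2, -1)) * scale
    if is_causal:
        L, S = scores.shape[-2], scores.shape[-1]
        # Lower-triangular causal mask, matching sdpa's is_causal=True behavior.
        mask = torch.ones(L, S, dtype=torch.bool, device=q.device).tril()
        scores = scores.masked_fill(~mask, float("-inf"))
    # lse[i] = log(sum_j exp(scores[i, j])): the value the kernel keeps for backward.
    lse = torch.logsumexp(scores, dim=-1)
    # softmax(scores) == exp(scores - lse), so the lse is enough to rebuild P.
    probs = torch.exp(scores - lse.unsqueeze(-1))
    out = (probs @ v.float()).to(q.dtype)
    return out, lse


torch.manual_seed(0)
q = torch.randn(2, 4, 8, 16)  # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)

out, lse = attention_with_lse(q, k, v, is_causal=True)
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(out, ref, atol=1e-4))
print(lse.shape)  # torch.Size([2, 4, 8])
```

Comparing against `F.scaled_dot_product_attention` checks that the recomputed output matches. The `lse` tensor then holds the per-query log-sum-exp you would otherwise only get inside the backward pass.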