Retrieving Attention Weights from scaled_dot_product_attention

Hi all!
I am using flash attention in my model implementation, and I was wondering if there is any way to get the attention weights computed inside torch.nn.functional.scaled_dot_product_attention. From the framework itself, there does not seem to be any way to do this.
I also wondered whether I could just reimplement this function in my model (a reference implementation is shown in the torch.nn.functional.scaled_dot_product_attention — PyTorch 2.2 documentation) and use the sdp_kernel context manager to get the same results, which would let me retrieve the attention weights from the function.
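For context, here is a minimal sketch of that second idea: a manual attention function following the reference formula shown in the scaled_dot_product_attention docs, modified to also return the attention weights. Note this computes the full L×S weight matrix in memory, so it loses the memory savings of the fused/flash kernels; the function name and exact signature here are just illustrative.

```python
import math
import torch

def sdpa_with_weights(query, key, value, attn_mask=None,
                      dropout_p=0.0, is_causal=False):
    """Manual scaled dot-product attention that also returns the
    attention weights. Follows the reference formula from the
    scaled_dot_product_attention documentation (illustrative sketch)."""
    L, S = query.size(-2), key.size(-2)
    scale = 1.0 / math.sqrt(query.size(-1))

    # Build an additive bias from the causal flag and/or an explicit mask.
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        causal = torch.ones(L, S, dtype=torch.bool, device=query.device).tril()
        attn_bias = attn_bias.masked_fill(~causal, float("-inf"))
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias = attn_bias.masked_fill(~attn_mask, float("-inf"))
        else:
            attn_bias = attn_bias + attn_mask

    # Materialize the attention weights explicitly -- this is the part
    # the fused kernels never expose.
    scores = query @ key.transpose(-2, -1) * scale + attn_bias
    attn_weights = torch.softmax(scores, dim=-1)
    attn_weights = torch.dropout(attn_weights, dropout_p, train=True)
    return attn_weights @ value, attn_weights
```

With dropout disabled, the output should match the fused function numerically, e.g.:

```python
q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
out, weights = sdpa_with_weights(q, k, v, is_causal=True)
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
torch.allclose(out, ref, atol=1e-5)  # weights rows also sum to 1
```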

Thank you!