flex_attention: returning the attention logits (pre-softmax scores)

I am running some experiments where I need access to the attention logits (the attention score tensor before the softmax). One example would be CoPE ([2405.18719] Contextual Position Encoding: Learning to Count What's Important).
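
For context, the only way I can see to get at those scores today is to materialize them eagerly with a plain matmul and re-apply the score_mod by hand, which gives up the fused kernel entirely. A rough sketch of that workaround (the shapes and the toy score_mod are placeholders for my setup, not anything from the flex_attention API):

```python
import torch

B, H, S, D = 2, 4, 128, 64
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)

def score_mod(score, b, h, q_idx, kv_idx):
    # toy relative-position bias; CoPE's counting/gating logic would go here
    return score - 0.01 * (q_idx - kv_idx).abs()

# materialize the full (B, H, S, S) logits tensor, pre-softmax
scale = D ** -0.5
logits = (q @ k.transpose(-2, -1)) * scale

# re-apply the score_mod elementwise via broadcasting
b_idx = torch.arange(B).view(B, 1, 1, 1)
h_idx = torch.arange(H).view(1, H, 1, 1)
q_idx = torch.arange(S).view(1, 1, S, 1)
kv_idx = torch.arange(S).view(1, 1, 1, S)
logits = score_mod(logits, b_idx, h_idx, q_idx, kv_idx)
```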

Any chance we could get an option to skip the softmax and return the logits instead?
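
For reference, this is roughly how I call it now (a minimal sketch, assuming the eager path; the causal score_mod is just a stand-in for my real one). As far as I can tell, only the post-softmax output comes back, and return_lse, if I'm reading it right, only exposes the logsumexp rather than the full score tensor:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 4, 128, 64
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))

def causal(score, b, h, q_idx, kv_idx):
    # stand-in score_mod; CoPE's position logic would live here
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

# only the (B, H, S, D) attention output is returned;
# the intermediate score tensor is never exposed
out = flex_attention(q, k, v, score_mod=causal)
```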