FlexAttention: returning score values from score_mod

I saw the newly released FlexAttention (blog post: "FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention" on the PyTorch blog) and I have a question.

I am wondering if it is possible to write to a globally scoped tensor, the way the alibi-bias example in that post reads from one.

Say I want to retrieve all the scores from the model to plot the attention matrix, or to sum the columns of the attention matrix to drive a KV-cache eviction policy, or some other use case.

Is it possible to accomplish this by writing to a globally initialized tensor? I tried the following, but it didn't work. Is there a way to accomplish this with flex_attention?

import torch
from torch.nn.attention.flex_attention import flex_attention

query = torch.randn(1, 8, 256, 128)
key = torch.randn(1, 8, 128, 128)
value = torch.randn(1, 8, 128, 128)

# (batch, heads, q_len, kv_len) buffer intended to capture every score.
scores_out = torch.zeros(1, 8, 256, 128)

def noop(score, b, h, q_idx, kv_idx):
    # In-place write to a captured tensor as a side effect -- this is the
    # part that fails: score_mod is traced/vmapped, so mutation of
    # captured tensors isn't supported.
    scores_out[b, h, q_idx, kv_idx] = score
    return score

out = flex_attention(query, key, value, score_mod=noop)

print(out.size())
print(scores_out)
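In the meantime, for visualization or column sums, the attention matrix can be recomputed outside flex_attention with plain tensor ops. A sketch of that workaround, assuming standard scaled-dot-product softmax attention with no score_mod applied:

```python
import torch

B, H, Q_LEN, KV_LEN, D = 1, 8, 256, 128, 64
query = torch.randn(B, H, Q_LEN, D)
key = torch.randn(B, H, KV_LEN, D)

# Recompute the scores manually: scaled dot product, then softmax over kv.
scores = (query @ key.transpose(-2, -1)) / (D ** 0.5)
attn = torch.softmax(scores, dim=-1)
print(attn.shape)  # torch.Size([1, 8, 256, 128])

# Column sums of the attention matrix, e.g. for a KV-cache eviction policy.
col_sums = attn.sum(dim=-2)
print(col_sums.shape)  # torch.Size([1, 8, 128])
```

This of course materializes the full (q_len, kv_len) matrix and duplicates the computation flex_attention does internally, so it only makes sense for debugging or offline analysis, not the hot path.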

Update: I started a discussion at "Writing to a globally scoped tensor from score_mod function" (issue #19 in pytorch-labs/attention-gym on GitHub), since that looks like a more appropriate place to discuss this topic.