I saw the newly released FlexAttention blog post ("FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention | PyTorch") and I have a question.
I am wondering whether it is possible to write to some globally scoped tensor, the same way the alibi bias example in that post reads from a globally scoped tensor.
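For reference, the read pattern I mean is roughly this (paraphrased from memory from the blog post; the random bias values here are just a stand-in, not real alibi slopes):

import torch

# globally scoped tensor that the score_mod closes over and reads,
# as in the blog post's alibi example; one stand-in slope per head
alibi_bias = torch.randn(8)

def alibi(score, b, h, q_idx, kv_idx):
    # reading from the captured global tensor is what the post demonstrates
    return score + alibi_bias[h] * (q_idx - kv_idx)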
Say I wanted to retrieve all the scores from the model to plot the attention matrix, or to sum the columns of the attention matrix to drive a KV cache eviction policy, or some other use case.
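Concretely, the quantity I want to extract looks like this, written with plain attention math rather than flex_attention, just to illustrate what I'm after:

import math
import torch

q = torch.randn(1, 8, 256, 128)
k = torch.randn(1, 8, 128, 128)

# full post-softmax attention matrix, shape (B, H, Q_LEN, KV_LEN)
attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)), dim=-1)

# e.g. plot attn[0, 0], or sum over the query dimension to score each
# key's total received attention for a KV cache eviction policy
key_importance = attn.sum(dim=-2)  # (B, H, KV_LEN)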
Is it possible to accomplish this by writing to a globally initialized tensor? I tried the following, but it didn't work. Is there another way to accomplish this with flex attention?
import torch
from torch.nn.attention.flex_attention import flex_attention

query = torch.randn(1, 8, 256, 128)
key = torch.randn(1, 8, 128, 128)
value = torch.randn(1, 8, 128, 128)

# globally scoped buffer meant to capture the score for every
# (batch, head, query, key) position: shape (B, H, Q_LEN, KV_LEN)
scores_out = torch.zeros(1, 8, 256, 128)

def noop(score, b, h, q_idx, kv_idx):
    # attempt to write each score into the global buffer
    scores_out[b, h, q_idx, kv_idx] = score
    return score

out = flex_attention(query, key, value, score_mod=noop)
print(out.size())
print(scores_out)