Hi, I am building a project where the attention matrix is sparse, and I got a huge performance boost from using FlexAttention with a block_mask.
However, my project also requires a highly customized additive bias, score_add, of shape [batch][num_heads][seq_len][seq_len] to be added to the score matrix. It is fully pre-computed by the dataloader, so its values cannot easily be expressed as a simple element-wise function of the arguments b, h, q_idx, and kv_idx.
I tried the following score_mod, but it won't compile:
def score_mod(score, b, h, q_idx, kv_idx):
    # add the pre-computed bias entry for this (batch, head, query, key) position
    return score + score_add[b, h, q_idx, kv_idx]
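
For context, here is a simplified sketch of how I am calling it; the shapes, the causal mask_mod, and the torch.compile wrapping are placeholders for my real setup:

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

B, H, S, D = 4, 8, 1024, 64  # placeholder batch, heads, seq_len, head_dim

q = torch.randn(B, H, S, D, device="cuda")
k = torch.randn(B, H, S, D, device="cuda")
v = torch.randn(B, H, S, D, device="cuda")

# Fully pre-computed additive bias from the dataloader, shape [B, H, S, S]
score_add = torch.randn(B, H, S, S, device="cuda")

# Placeholder sparsity pattern; my real mask_mod is more involved
def mask_mod(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(mask_mod, B, H, S, S, device="cuda")

def score_mod(score, b, h, q_idx, kv_idx):
    # look up the pre-computed bias for this (batch, head, query, key) entry
    return score + score_add[b, h, q_idx, kv_idx]

flex = torch.compile(flex_attention)
out = flex(q, k, v, score_mod=score_mod, block_mask=block_mask)  # this is where it fails
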
Do we have any good solutions for this?