Flex Attention Score_mod with learnable parameters

Hello!

I want to use the new flex_attention function to implement biased attention. In my case I have a graph transformer and want to modify the attention scores based on the shortest distances between nodes. I have an Embedding module that holds a learned parameter for each distance. My question is:
Is it possible to call an external module (in my case the Embedding module with a learned parameter per distance) from the score_mod function, or is there another way to do this?
Thank you in advance :slight_smile:

Here is some code to demonstrate:

import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import flex_attention, create_block_mask


class GraphAttnBias(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # one learned bias per distance value, with one entry per attention head
        self.attn_bias = nn.Embedding(num_embeddings=cfg.vocab_size, embedding_dim=cfg.num_heads)

    def forward(self, idx):
        return self.attn_bias(idx)

class CausalFlexAttention(nn.Module):
    
    def __init__(self, cfg, embed_dim, num_heads, dropout, bias):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.c_attn = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.c_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        self.n_head = num_heads
        self.n_embd = embed_dim
        self.dropout = dropout
       
        # stuff for attention bias
        self.attn_bias = GraphAttnBias(cfg)
    
    def score_mod(self, score, batch, head, q_idx, kv_idx):
        # here I want to call the embedding module with the distances like this,
        # assuming distance has shape (batch_size, seq_length, seq_length);
        # the embedding returns one bias per head, so index it with the current head
        return score + self.attn_bias(distance[batch, q_idx, kv_idx])[head]

    def causal_mask(self, b, h, q_idx, kv_idx):
        return q_idx >= kv_idx
    
    def forward(self, x):
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        
        # block_mask expects a BlockMask built with create_block_mask,
        # not the mask function itself
        block_mask = create_block_mask(self.causal_mask, B, self.n_head, T, T)

        y = flex_attention(
            q,
            k,
            v,
            score_mod=self.score_mod,
            block_mask=block_mask,
        )

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))
        return y
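
For reference, this is roughly how I imagine the score_mod closing over the embedding table and a precomputed distance tensor, as a standalone sketch (bias_table, distance, num_distances and all the shapes are placeholders I made up, and I'm not sure whether gradients actually flow into the embedding this way):

import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

device = "cuda"
B, H, T, D = 2, 4, 128, 16
q = torch.randn(B, H, T, D, device=device)
k = torch.randn(B, H, T, D, device=device)
v = torch.randn(B, H, T, D, device=device)

num_distances = 32  # placeholder for the number of distinct shortest-path distances
bias_table = nn.Embedding(num_distances, H).to(device)  # learned bias per (distance, head)
distance = torch.randint(0, num_distances, (B, T, T), device=device)  # precomputed distances

def score_mod(score, b, h, q_idx, kv_idx):
    # look up the learned bias for the shortest-path distance of this (q, k) pair and head
    return score + bias_table.weight[distance[b, q_idx, kv_idx], h]

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, B, H, T, T, device=device)
out = flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask)

Does capturing the module/tensors like this work, or do I have to pass the learned bias in some other way?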