Hello!
I want to use the new flex_attention function to implement biased attention. In my case I have a graph transformer and want to modify the attention scores based on the shortest-path distances between nodes. I have an Embedding module that has a learned parameter for each distance. My question is:
Is it possible to call an external module (in my case, the Embedding module that holds a learned parameter for each distance) from the score_mod function, or is there another way to do this?
Thank you in advance!
Here is some code to demonstrate:
import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import flex_attention, create_block_mask


class GraphAttnBias(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # one learned bias per (distance, head) pair
        self.attn_bias = nn.Embedding(num_embeddings=cfg.vocab_size, embedding_dim=cfg.num_heads)

    def forward(self, idx):
        return self.attn_bias(idx)
class CausalFlexAttention(nn.Module):
    def __init__(self, cfg, embed_dim, num_heads, dropout, bias):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.c_attn = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.c_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        self.n_head = num_heads
        self.n_embd = embed_dim
        self.dropout = dropout
        # stuff for attention bias: registered as a submodule so it is trained with the model
        self.attn_bias = GraphAttnBias(cfg)

    def score_mod(self, score, batch, head, q_idx, k_idx):
        # here I want to call the embedding module with the distances like this,
        # assuming distance has shape (batch_size, seq_length, seq_length);
        # the lookup returns one bias per head, so I index it with head
        return score + self.attn_bias(distance[batch, q_idx, k_idx])[head]

    def causal_mask(self, b, h, q_idx, kv_idx):
        return q_idx >= kv_idx
    def forward(self, x):
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        # build the causal BlockMask for this sequence length
        block_mask = create_block_mask(self.causal_mask, B=None, H=None, Q_LEN=T, KV_LEN=T, device=x.device)
        y = flex_attention(
            q,
            k,
            v,
            score_mod=self.score_mod,
            block_mask=block_mask,
        )
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))
        return y
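
To make the question more concrete, here is a rough, self-contained sketch of what I mean. The sizes, the distance tensor, and the dist_bias name are just placeholders I made up for illustration, and I index the embedding weight directly inside score_mod as the closest thing I could think of; whether calling the GraphAttnBias module itself from score_mod is supported (and whether gradients flow back to its parameters) is exactly what I am unsure about.

import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import flex_attention

B, H, T, D = 2, 4, 128, 16   # placeholder: batch size, heads, sequence length, head dim
max_dist = 32                # placeholder: maximum shortest-path distance

# learned bias: one value per (distance, head) pair, like GraphAttnBias above
dist_bias = nn.Embedding(num_embeddings=max_dist, embedding_dim=H)

# placeholder shortest-path distances between nodes, shape (B, T, T)
distance = torch.randint(0, max_dist, (B, T, T))

def score_mod(score, b, h, q_idx, kv_idx):
    # look up the learned bias for this pair's distance and this head by
    # indexing the captured embedding weight directly
    return score + dist_bias.weight[distance[b, q_idx, kv_idx], h]

q = torch.randn(B, H, T, D)
k = torch.randn(B, H, T, D)
v = torch.randn(B, H, T, D)

out = flex_attention(q, k, v, score_mod=score_mod)
print(out.shape)  # (B, H, T, D)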