Hello, I am trying to implement multi-head self-attention with torch.nn.functional.scaled_dot_product_attention (SDPA), but I am not sure how to transform the src_key_padding_mask that nn.TransformerEncoder usually takes into the attn_mask that SDPA expects. Is the implementation below correct?
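For context, my understanding of the two conventions is that src_key_padding_mask uses True to mark padded positions that should be ignored, while a boolean attn_mask for SDPA uses True for positions that are allowed to attend, so the conversion should just be an inversion plus a reshape to something broadcastable over (batch, heads, query, key). A toy illustration of what I mean (batch size 2 and sequence length 4 are made up for the example):

import torch

# Made-up padding mask, following the nn.TransformerEncoder convention:
# True = padded position that should be ignored.
src_key_padding_mask = torch.tensor([
    [False, False, True, True],   # sequence 1: last two tokens are padding
    [False, False, False, True],  # sequence 2: last token is padding
])

# Boolean mask for F.scaled_dot_product_attention: True = "may attend".
# Shape (B, 1, 1, S) broadcasts over heads and query positions.
attn_mask = ~src_key_padding_mask[:, None, None, :]
print(attn_mask.shape)  # torch.Size([2, 1, 1, 4])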
from torch import nn
import torch.nn.functional as F
from typing import Callable, Optional


class SelfAttention(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
        rope: Optional[Callable] = None,
        is_causal: bool = False,
    ):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.dropout = dropout
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads
        self.rope = rope
        self.is_causal = is_causal

    def forward(self, x, key_padding_mask=None):
        # x: (B, S, embed_dim); key_padding_mask: (B, S) bool, True = padded position.
        qkv = self.qkv_proj(x)                                        # (B, S, 3 * embed_dim)
        qkv = qkv.unflatten(-1, [self.num_heads, 3 * self.head_dim])  # (B, S, num_heads, 3 * head_dim)
        q, k, v = qkv.chunk(3, dim=-1)                                # each (B, S, num_heads, head_dim)
        if self.rope is not None:
            q = self.rope(q)
            k = self.rope(k)
        attn_mask = None
        if key_padding_mask is not None:
            # SDPA's boolean attn_mask uses True for positions that may attend,
            # so invert the padding mask and broadcast it over heads and query
            # positions: (B, S) -> (B, num_heads, 1, S).
            B, S = key_padding_mask.shape
            attn_mask = (
                ~key_padding_mask
                .view(B, 1, 1, S)
                .expand(-1, self.num_heads, -1, -1)
            )
        attn_output = F.scaled_dot_product_attention(
            query=q.transpose(1, 2),  # (B, num_heads, S, head_dim)
            key=k.transpose(1, 2),
            value=v.transpose(1, 2),
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=self.is_causal if self.training else False,
        )
        attn_output = attn_output.transpose(1, 2).flatten(-2)  # back to (B, S, embed_dim)
        return self.out_proj(attn_output)
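For what it's worth, this is the sanity check I intend to run against the module (the sizes and the mask are made up for the test): if the padding mask is wired in correctly, scrambling the content of the padded positions should leave the outputs at the non-padded positions unchanged.

import torch

torch.manual_seed(0)
model = SelfAttention(embed_dim=16, num_heads=4).eval()

B, S = 2, 5
x = torch.randn(B, S, 16)
# True = padded, matching the src_key_padding_mask convention.
key_padding_mask = torch.tensor([
    [False, False, False, True, True],
    [False, False, False, False, True],
])

with torch.no_grad():
    out_a = model(x, key_padding_mask)
    x_scrambled = x.clone()
    # Overwrite only the padded tokens with fresh random values.
    x_scrambled[key_padding_mask] = torch.randn_like(x_scrambled[key_padding_mask])
    out_b = model(x_scrambled, key_padding_mask)

# Compare only the non-padded positions; the outputs at padded positions are irrelevant.
print(torch.allclose(out_a[~key_padding_mask], out_b[~key_padding_mask]))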