Correct SDPA's attn_mask for Self-attention

Hello, I am trying to implement multi-head self-attention using torch.nn.functional.scaled_dot_product_attention (SDPA), but I am not sure how to convert the src_key_padding_mask that nn.TransformerEncoder usually takes into the attn_mask that SDPA expects. Is this implementation correct?
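
As far as I understand (please correct me if I am wrong), src_key_padding_mask marks padded positions with True, while a boolean attn_mask for SDPA marks positions that should take part in attention with True, so the mask has to be inverted and reshaped so it broadcasts over heads and query positions. A tiny sketch of what I mean, with made-up values:

import torch

# two sequences of length 4; True marks padding (nn.TransformerEncoder convention)
key_padding_mask = torch.tensor([[False, False, False, True],
                                 [False, False, True,  True]])
# invert (True now means "attend") and reshape to (B, 1, 1, S) for broadcasting
attn_mask = ~key_padding_mask.view(2, 1, 1, 4)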

import torch
from torch import nn
import torch.nn.functional as F
from typing import Callable, Optional

class SelfAttention(nn.Module):

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
        rope: Optional[Callable] = None,
        is_causal: bool = False,
    ):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.dropout = dropout
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads
        self.rope = rope
        self.is_causal = is_causal

    def forward(self, x, key_padding_mask=None):
        # x: (B, S, E) -> qkv: (B, S, 3E)
        qkv = self.qkv_proj(x)
        # split into per-head q, k, v, each of shape (B, S, H, head_dim)
        qkv = qkv.unflatten(-1, [self.num_heads, 3 * self.head_dim])
        q, k, v = qkv.chunk(3, dim=-1)
        if self.rope is not None:
            q = self.rope(q)
            k = self.rope(k)

        attn_mask = None
        if key_padding_mask is not None:
            B, S = key_padding_mask.shape
            # key_padding_mask: True = padded (ignore); SDPA boolean attn_mask: True = attend.
            # SDPA broadcasts a (B, 1, 1, S) mask over heads and query positions,
            # so expanding to num_heads is not required.
            attn_mask = ~key_padding_mask.view(B, 1, 1, S)
            if self.is_causal:
                # SDPA raises an error if both attn_mask and is_causal are set,
                # so fold the causal constraint into the boolean mask here
                causal = torch.ones(S, S, dtype=torch.bool, device=x.device).tril()
                attn_mask = attn_mask & causal

        attn_output = F.scaled_dot_product_attention(
            query=q.transpose(1, 2),  # (B, H, S, head_dim)
            key=k.transpose(1, 2),
            value=v.transpose(1, 2),
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            # causal masking should also apply at eval time; only rely on is_causal
            # when no explicit mask is passed (SDPA rejects both together)
            is_causal=self.is_causal and attn_mask is None,
        )

        # back to (B, S, E)
        attn_output = attn_output.transpose(1, 2).flatten(-2)

        return self.out_proj(attn_output)
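
This is how I intend to use it (a minimal sketch; the shapes and padding pattern are made up for illustration):

attn = SelfAttention(embed_dim=16, num_heads=4)
x = torch.randn(2, 5, 16)  # (B, S, E)
# True marks padded positions, matching nn.TransformerEncoder's src_key_padding_mask
key_padding_mask = torch.tensor([[False, False, False, True, True],
                                 [False, False, False, False, False]])
out = attn(x, key_padding_mask=key_padding_mask)
print(out.shape)  # torch.Size([2, 5, 16])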