attn_mask implementation in torch.nn.functional

In _scaled_dot_product_attention() in torch.nn.functional, attn_mask is used like this:

    if attn_mask is not None:
        attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1))
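One way to confirm what the baddbmm call above computes is to compare it with an explicit addition. torch.baddbmm(input, batch1, batch2) returns beta * input + alpha * (batch1 @ batch2), with beta and alpha both defaulting to 1, and input only needs to be broadcastable to the batched result. The toy shapes and the mask below are made up for illustration; the sketch assumes the mask is a float tensor with 0.0 at allowed positions and -inf at blocked ones.

```python
import torch

# Toy shapes (hypothetical): batch 2, 3 query positions, 4 key positions, head dim 5.
torch.manual_seed(0)
B, L, S, E = 2, 3, 4, 5
q = torch.randn(B, L, E)
k = torch.randn(B, S, E)

# Assumed additive float mask: 0.0 keeps a position, -inf blocks it.
# Here every query is blocked from attending to the last key.
attn_mask = torch.zeros(L, S)
attn_mask[:, -1] = float("-inf")

# baddbmm adds the (broadcast) 2D mask to the batched q @ k^T scores.
via_baddbmm = torch.baddbmm(attn_mask, q, k.transpose(-2, -1))
via_add = attn_mask + q @ k.transpose(-2, -1)

print(torch.allclose(via_baddbmm, via_add))  # same result up to rounding
print(via_baddbmm[..., -1])                  # the blocked column is all -inf
```

So the mask itself does nothing special at this step; it just shifts the raw scores.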

Since baddbmm essentially just adds attn_mask to q @ k.T, I was wondering how this does what the documentation says attn_mask does: (" attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape")
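A possible way to see how an additive mask can prevent attention is to follow the scores through the softmax that, in standard scaled dot-product attention, is applied over the key dimension right after this step. If a blocked position holds -inf (an assumption about how the mask is encoded; a boolean mask would first need converting to this form), exp(-inf) is 0, so that key ends up with exactly zero weight. The scores below are hypothetical, and the softmax is a plain-Python sketch rather than PyTorch's implementation.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats.

    Entries equal to -inf get weight 0.0; assumes at least one finite entry.
    """
    m = max(row)
    exps = [math.exp(x - m) for x in row]  # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical raw scores q . k for one query against four keys.
scores = [2.0, 0.5, 1.0, 3.0]
# Assumed additive mask: 0.0 = allowed, -inf = blocked (blocks the last key).
mask = [0.0, 0.0, 0.0, float("-inf")]

unmasked = softmax(scores)
masked = softmax([s + m for s, m in zip(scores, mask)])

print(unmasked)  # the last key gets the largest weight
print(masked)    # the last key gets weight 0.0; the rest renormalize to sum to 1
```

Adding -inf before the softmax is what turns "add a mask" into "prevent attention"; a large negative finite value behaves almost the same but leaves a tiny nonzero weight.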

Thanks!