Hi! The last axis of your mask tensor should match the key's sequence length, i.e. the second-to-last axis of your k tensor (axis 2 in the example below).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
k = torch.randn(2, 4, 4, 20).cuda()   # (batch, heads, seq_len, embed_dim)
mask = torch.ones(2, 4, 1, 4).cuda()  # last axis (4) matches k's seq_len axis
result = F.scaled_dot_product_attention(k, k, k, mask)  # self-attention: q = k = v
Beyond that there is no strict shape restriction: query, key and value just need the same number of dimensions, query and key must share the last (embedding) dimension, and key and value must share the sequence-length dimension; the mask then only has to be broadcastable to (batch, ..., L, S), where L is the query length and S is the key length. You can refer to this link for the implementation of F.scaled_dot_product_attention.
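As a rough sketch of those shape rules (the tensor sizes here are made up for illustration, and it runs on CPU so no .cuda() is needed), a cross-attention call with different query and key lengths could look like this:

import torch
import torch.nn.functional as F

# Hypothetical shapes: L = 6 query positions, S = 4 key/value positions
q = torch.randn(2, 4, 6, 20)   # (batch, heads, L, embed_dim)
k = torch.randn(2, 4, 4, 20)   # (batch, heads, S, embed_dim) - same embed_dim as q
v = torch.randn(2, 4, 4, 32)   # (batch, heads, S, value_dim)  - same S as k
mask = torch.zeros(2, 4, 6, 4, dtype=torch.bool)  # broadcastable to (batch, heads, L, S)
mask[..., :2] = True           # e.g. only attend to the first two key positions
out = F.scaled_dot_product_attention(q, k, v, mask)
print(out.shape)               # torch.Size([2, 4, 6, 32]), i.e. (batch, heads, L, value_dim)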