Sizes do not match in scaled_dot_product_attention

The following code is expected to work, but it is throwing an error.

q, k, v = q.reshape(-1, seq_len, self.num_attention_heads, self.attention_head_size).transpose(1, 2), \
      k.reshape(-1, seq_len, self.num_attention_heads, self.attention_head_size).transpose(1, 2), \
      v.reshape(-1, seq_len, self.num_attention_heads, self.attention_head_size).transpose(1, 2)

print(q.shape, k.shape, v.shape)

output = nn.functional.scaled_dot_product_attention(
   q, k, v, attn_mask, self.config.attention_probs_dropout_prob
)
The print statement shows:

torch.Size([80, 8, 128, 64]) torch.Size([80, 8, 128, 64]) torch.Size([80, 8, 128, 64])

and the scaled_dot_product_attention call then fails with:

  File "", line 81, in forward
    output = nn.functional.scaled_dot_product_attention(
RuntimeError: The size of tensor a (128) must match the size of tensor b (80) at non-singleton dimension 2

I am using PyTorch 2.0.1.

Could you post more details about your setup (such as the specific GPU you are using) and a minimal, runnable example that reproduces the error? I could not reproduce it with:

import torch
from torch import nn

q = torch.randn(80, 128, 8, 64, device='cuda', dtype=torch.half).transpose(1, 2)
k = torch.randn(80, 128, 8, 64, device='cuda', dtype=torch.half).transpose(1, 2)
v = torch.randn(80, 128, 8, 64, device='cuda', dtype=torch.half).transpose(1, 2)

print(q.shape, k.shape, v.shape)

attn_mask = torch.ones(128, 128, device='cuda', dtype=torch.half)

output = nn.functional.scaled_dot_product_attention(
   q, k, v, attn_mask=attn_mask)
which prints:

torch.Size([80, 8, 128, 64]) torch.Size([80, 8, 128, 64]) torch.Size([80, 8, 128, 64])

and runs without raising the error.

EDIT: could it be that your attention mask does not have the expected [128, 128] shape and one of its dimensions is set to 80 instead?
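For example, here is a minimal sketch (CPU, default math backend) assuming the mask was created as [batch_size, seq_len], i.e. [80, 128], which is a common shape for a padding mask. Broadcasting such a mask against the [80, 8, 128, 128] attention scores should produce the same size mismatch at dimension 2, while expanding it to a broadcastable [80, 1, 1, 128] shape avoids it:

import torch
from torch import nn

q = torch.randn(80, 128, 8, 64).transpose(1, 2)
k = torch.randn(80, 128, 8, 64).transpose(1, 2)
v = torch.randn(80, 128, 8, 64).transpose(1, 2)

# hypothetical padding mask created as [batch_size, seq_len] = [80, 128]
bad_mask = torch.zeros(80, 128)

try:
    nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=bad_mask)
except RuntimeError as e:
    # prints a size mismatch between 128 and 80 at non-singleton dimension 2
    print(e)

# expanding the mask to [batch_size, 1, 1, seq_len] makes it broadcastable
good_mask = bad_mask[:, None, None, :]
output = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=good_mask)
print(output.shape)  # torch.Size([80, 8, 128, 64])

If that is the case, either building the mask directly as [128, 128] or unsqueezing it to a broadcastable 4D shape should fix the call.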