The following code should work, but it raises a RuntimeError:
q = q.reshape(-1, seq_len, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
k = k.reshape(-1, seq_len, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
v = v.reshape(-1, seq_len, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
print(q.shape, k.shape, v.shape)
output = nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=attn_mask, dropout_p=self.config.attention_probs_dropout_prob
)
torch.Size([80, 8, 128, 64]) torch.Size([80, 8, 128, 64]) torch.Size([80, 8, 128, 64])
File "", line 81, in forward
output = nn.functional.scaled_dot_product_attention(
RuntimeError: The size of tensor a (128) must match the size of tensor b (80) at non-singleton dimension 2
I am using PyTorch 2.0.1. What is causing this shape mismatch?
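For context, here is a minimal standalone sketch of what I believe is happening. The `attn_mask` shape is an assumption inferred from the error message: with q/k/v of shape [80, 8, 128, 64], the attention weights are [80, 8, 128, 128], and a 2-D mask of shape [80, 128] cannot broadcast against them (trailing-dim alignment puts the batch size 80 against the sequence length 128 at dimension 2). Inserting singleton head and query dimensions appears to fix it:

```python
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 80, 8, 128, 64
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# Assumed mask shape [batch, seq_len]: a boolean padding mask where True
# means "attend to this position". This shape cannot broadcast against the
# [batch, heads, seq_len, seq_len] attention weights.
mask_2d = torch.zeros(batch, seq_len, dtype=torch.bool)
mask_2d[:, :100] = True  # attend only to the first 100 positions

raised = False
try:
    F.scaled_dot_product_attention(q, k, v, attn_mask=mask_2d)
except RuntimeError:
    raised = True  # reproduces the broadcast error from the question

# Adding singleton head and query-position dims makes the mask broadcastable:
# [batch, 1, 1, seq_len] expands cleanly to [batch, heads, seq_len, seq_len].
mask_4d = mask_2d[:, None, None, :]
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask_4d)
print(out.shape)
```

If this matches your setup, reshaping `attn_mask` to `[batch, 1, 1, seq_len]` (or building an additive float mask of that shape) before the call should resolve the error.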