How is dim for torch.matmul determined?

I know that the attention code uses matmul like this, and that this code works.
But using matmul under the exact same conditions gives an error. Why is this?
I want to control the dims so that the batch and attention-head dimensions are accounted for. How do I control which dimension matmul computes over?
Attention code and printed output:

        attn = self.dropout(attn)

        context = torch.matmul(attn, value).transpose(1, 2)
        context = context.contiguous().view(batch_size, -1, self.d_model)

attn torch.Size([4, 16, 100, 100])
torch.Size([4, 16, 100, 32])
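To make the shape bookkeeping in that snippet explicit, here is a standalone sketch that traces each step. It assumes d_model = num_heads × head_dim = 16 × 32 = 512, because self.d_model is not shown in the post; the variable names are illustrative, not taken from the original module.

```python
import torch

# Shapes from the printed output above; d_model = 16 * 32 = 512 is an assumption.
batch_size, num_heads, seq_len, head_dim = 4, 16, 100, 32
d_model = num_heads * head_dim

attn = torch.randn(batch_size, num_heads, seq_len, seq_len).softmax(dim=-1)  # [4, 16, 100, 100]
value = torch.randn(batch_size, num_heads, seq_len, head_dim)                # [4, 16, 100, 32]

# matmul treats the leading [4, 16] as batch dims and multiplies the trailing
# [100, 100] @ [100, 32] matrices, contracting the size-100 dim.
context = torch.matmul(attn, value)
print(context.shape)  # torch.Size([4, 16, 100, 32])

# Move the head dim next to head_dim so the heads can be merged.
context = context.transpose(1, 2)
print(context.shape)  # torch.Size([4, 100, 16, 32])

# transpose returns a non-contiguous view, so .contiguous() is needed before .view().
context = context.contiguous().view(batch_size, -1, d_model)
print(context.shape)  # torch.Size([4, 100, 512])
```

So the batch and head dims are never "chosen" by matmul; they are just carried along as batch dims, and the transpose/view afterwards is what merges the heads back into d_model.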

My test code and the error:

v=torch.randn(4, 16, 100, 100)
at=torch.randn(4, 16, 100,32)
context = torch.matmul(at, v)
Traceback (most recent call last):
  File "", line 129, in <module>
    context = torch.matmul(at, v)
RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [64, 32] but got: [64, 100].

The error message is a bit confusing here because it comes from torch.bmm, which matmul reduces to (I think): the leading batch dims 4 × 16 are flattened into the 64 you see in the message. The actual problem is that the second-to-last dimension of the right argument (v, size 100) must match the last dimension of the left argument (at, size 32), because this is the contraction dimension.
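A minimal sketch of the rule, reusing the shapes from the test code above: matmul contracts the last dim of the left operand with the second-to-last dim of the right one, and treats all leading dims as (broadcastable) batch dims. Swapping the operands or transposing one of them is how you pick which dim gets contracted. The bmm line only illustrates the flattening idea and is not a claim about matmul's exact internal code path.

```python
import torch

torch.manual_seed(0)
at = torch.randn(4, 16, 100, 32)
v = torch.randn(4, 16, 100, 100)

# matmul multiplies the last two dims: [..., n, k] @ [..., k, m] -> [..., n, m].
# at is [..., 100, 32] and v is [..., 100, 100], so k would be 32 on the left
# but 100 on the right; that mismatch raises the RuntimeError.
try:
    torch.matmul(at, v)
except RuntimeError as err:
    print("matmul(at, v) fails:", err)

# Same operand order as the attention code (attn @ value): k = 100 on both sides.
context = torch.matmul(v, at)
print(context.shape)  # torch.Size([4, 16, 100, 32])

# The leading [4, 16] dims are batch dims. Flattening them to 4 * 16 = 64 and
# calling bmm gives the same result, which matches the 64 in the error message.
context_bmm = torch.bmm(v.reshape(64, 100, 100), at.reshape(64, 100, 32)).reshape(4, 16, 100, 32)
print(torch.allclose(context, context_bmm, atol=1e-5))  # True

# Batch dims broadcast: a [16, 100, 100] tensor is reused for each of the 4 batch items.
shared = torch.matmul(v[0], at)
print(shared.shape)  # torch.Size([4, 16, 100, 32])

# To contract a different dim, move it into place first:
# at^T @ v contracts the size-100 dim of both tensors -> [4, 16, 32, 100].
other = torch.matmul(at.transpose(-2, -1), v)
print(other.shape)  # torch.Size([4, 16, 32, 100])
```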

Best regards


Just adding to @tom’s answer: You could use torch.einsum() to write more readable code.

import torch

# attention operation
q = torch.randn(4, 16, 100, 32) # batch, head, N, dim
k = torch.randn(4, 16, 100, 32)
v = torch.randn(4, 16, 100, 32)

attention_logits = torch.einsum("bhnd,bhmd->bhnm", q, k)   # 4, 16, 100, 100
attention_logits = attention_logits.softmax(dim=-1)
context = torch.einsum("bhnm,bhmd->bhnd", attention_logits, v)
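As a quick sanity check, the einsum version can be compared against the equivalent matmul calls; the only change needed for matmul is transposing k so that d is the contracted dim. This is a sketch: like the snippet above it omits the usual 1/sqrt(d) scaling, and atol=1e-4 is an arbitrary tolerance for float32 round-off.

```python
import torch

torch.manual_seed(0)
q = torch.randn(4, 16, 100, 32)  # batch, head, N, dim
k = torch.randn(4, 16, 100, 32)
v = torch.randn(4, 16, 100, 32)

# einsum version from the post: the index letters name every dim explicitly.
logits_einsum = torch.einsum("bhnd,bhmd->bhnm", q, k)
context_einsum = torch.einsum("bhnm,bhmd->bhnd", logits_einsum.softmax(dim=-1), v)

# matmul version: swap k's last two dims so that d is contracted.
logits_matmul = torch.matmul(q, k.transpose(-2, -1))
context_matmul = torch.matmul(logits_matmul.softmax(dim=-1), v)

print(context_einsum.shape)  # torch.Size([4, 16, 100, 32])
print(torch.allclose(context_einsum, context_matmul, atol=1e-4))
```

The einsum form avoids the transpose because the contracted index (d, then m) is named directly rather than implied by position.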