How does the @ sign work in this instance?

I don’t understand why this code snippet gives the output it gives. Can you explain the math behind this or post a link to someplace explaining higher-dimensional matrix multiplication? Thanks!

import torch

A = torch.randn(1, 64, 1152, 1, 8)
B = torch.randn(10, 1, 1152, 8, 16)

C = A @ B
print(C.size())

output: torch.Size([10, 64, 1152, 1, 16])


The matrix multiplications are done between the last two dimensions (1×8 @ 8×16 --> 1×16). The remaining first three dimensions are broadcast and treated as batch dimensions, so you get 10×64×1152 matrix multiplications.

Best regards

Thomas
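To see this in action, here is a minimal sketch (using the shapes from the original snippet) that checks one batch entry of the broadcast result against an explicit 1×8 @ 8×16 product:

```python
import torch

# Batch dimensions broadcast like elementwise ops; the last two
# dimensions are matrix-multiplied: 1x8 @ 8x16 -> 1x16.
A = torch.randn(1, 64, 1152, 1, 8)
B = torch.randn(10, 1, 1152, 8, 16)
C = A @ B

# Broadcast batch shape: (1, 64, 1152) and (10, 1, 1152) -> (10, 64, 1152)
print(C.shape)  # torch.Size([10, 64, 1152, 1, 16])

# One entry of the batched result equals the corresponding plain
# matrix product; A broadcasts along its first dimension.
c = A[0, 3, 5] @ B[7, 0, 5]           # (1, 8) @ (8, 16) -> (1, 16)
print(torch.allclose(C[7, 3, 5], c))  # True
```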


Thank you that’s very helpful!

torch.bmm()

https://pytorch.org/docs/stable/generated/torch.bmm.html?highlight=bmm#torch.bmm

@tom Is the @ operator documented somewhere? What if A and B have different numbers of dimensions? What if either A or B only has one dimension? I’m trying to understand how this operator can be used to calculate the weighted average according to a formula given in this answer.

It’s a bit tricky and the link is not obvious, but in fact the A @ B operator internally maps to torch.matmul(A, B), so you can use the linked documentation for the latter.

Best regards

Thomas
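Since `A @ B` maps to `torch.matmul(A, B)`, the 1-D and mixed-dimension cases asked about above follow the rules in the `torch.matmul` documentation: two 1-D tensors give a dot product, a 1-D tensor gets a dimension of size 1 prepended or appended (and removed afterwards) when paired with a higher-dimensional tensor, and extra leading dimensions are broadcast. A small sketch of those cases:

```python
import torch

v = torch.randn(3)
w = torch.randn(3)
M = torch.randn(2, 3)
batch = torch.randn(5, 2, 3)

print((v @ w).shape)      # torch.Size([])   -- 1-D @ 1-D is a dot product (0-d tensor)
print((M @ v).shape)      # torch.Size([2])  -- matrix-vector product
print((batch @ v).shape)  # torch.Size([5, 2]) -- v is broadcast over the batch dimension
```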

Is there a class or tutorial that talks about this? I guess this is math? I didn’t see this in my linear algebra class.
Thank you.

To my mind, the trouble with maths lectures is that, of all the possible explanations of a given thing, the subset that resonates with a student is very individual, so whether the explanation presented in a class happens to be one that resonates with you is a bit of a chance thing.

That said, I like to identify n x m matrices with linear functions mapping R^m to R^n by multiplication on the left. Then, when you spell out the composition of these functions and see that it is represented by the product matrix, this tells you how the dimensions have to match up in matrix multiplication.

Best regards

Thomas
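That composition view can be checked numerically: if A represents a map R^m -> R^n and B a map R^k -> R^m, then applying B and then A to a vector x gives the same result as multiplying by the product matrix A @ B (up to floating-point tolerance). A small sketch, with arbitrarily chosen dimensions:

```python
import torch

n, m, k = 4, 3, 2
A = torch.randn(n, m)  # linear map R^m -> R^n
B = torch.randn(m, k)  # linear map R^k -> R^m
x = torch.randn(k)

# Composing the two maps equals multiplying by the product matrix,
# which is why the inner dimensions (m and m) must agree.
print(torch.allclose(A @ (B @ x), (A @ B) @ x, atol=1e-5))  # True
```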

Small correction for people who stumble upon this thread: A @ B calls torch.matmul(B, A)

I don’t think so:

import torch

A = torch.randn(1, 2)
B = torch.randn(2, 3)

# works
C = torch.matmul(A, B)
print(C.shape)
# torch.Size([1, 3])

# works
C = A @ B
print(C.shape)
# torch.Size([1, 3])

# breaks
torch.matmul(B, A)
# RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x3 and 1x2)

# breaks
B @ A
# RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x3 and 1x2)

Confirming @ptrblck’s comment. A @ B is in fact equal to torch.matmul(A, B), not torch.matmul(B, A).