How does the @ sign work in this instance?

I don’t understand why this code snippet gives the output it gives. Can you explain the math behind this or post a link to someplace explaining higher dimensional matrix multiplication? Thanks!

import torch

A = torch.randn(1, 64, 1152, 1, 8)
B = torch.randn(10, 1, 1152, 8, 16)

C = A @ B
print(C.size())

output: torch.Size([10, 64, 1152, 1, 16])


The matrix multiplications are done over the last two dimensions (1×8 @ 8×16 → 1×16). The remaining leading three dimensions are broadcast against each other and act as batch dimensions, so you get 10×64×1152 independent matrix multiplications.
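A small sketch of this behavior (shapes here are made up for brevity; any broadcast-compatible leading dimensions work the same way):

```python
import torch

# Leading dims broadcast like elementwise ops; the last two dims
# are matrix-multiplied: (1, 8) @ (8, 16) -> (1, 16).
A = torch.randn(1, 3, 1, 8)   # size-1 dim 0 broadcasts against B's dim 0
B = torch.randn(5, 1, 8, 16)  # size-1 dim 1 broadcasts against A's dim 1

C = A @ B
print(C.shape)  # torch.Size([5, 3, 1, 16])

# Each batch entry is an ordinary 2-D matmul of the broadcast slices:
assert torch.allclose(C[4, 2], A[0, 2] @ B[4, 0])
```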

Best regards

Thomas


Thank you that’s very helpful!

torch.bmm()

https://pytorch.org/docs/stable/generated/torch.bmm.html?highlight=bmm#torch.bmm
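For contrast, torch.bmm is the stricter cousin of @: it requires exactly 3-D inputs of matching batch size and does no broadcasting, so it would reject the 5-D tensors above. A quick sketch:

```python
import torch

# bmm: (batch, n, m) @ (batch, m, p) -> (batch, n, p), no broadcasting.
A = torch.randn(10, 1, 8)
B = torch.randn(10, 8, 16)

C = torch.bmm(A, B)
print(C.shape)  # torch.Size([10, 1, 16])

# For 3-D inputs of equal batch size, bmm agrees with the @ operator:
assert torch.allclose(C, A @ B)
```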

@tom Is the @ operator documented somewhere? What if A and B have different numbers of dimensions? What if either A or B has only one dimension? I’m trying to understand how this operator can be used to calculate the weighted average according to a formula given in this answer.

It’s a bit tricky and the link is not obvious, but in fact the A @ B operator internally maps to torch.matmul(A, B), so you can use the linked documentation for the latter.
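The matmul documentation also answers the 1-D questions directly; a short sketch of those documented rules:

```python
import torch

v = torch.randn(8)
M = torch.randn(8, 16)
T = torch.randn(10, 8, 16)

# 1-D @ 1-D is a dot product, returning a 0-D scalar tensor.
print((v @ v).shape)  # torch.Size([])

# 1-D @ 2-D: v is treated as (1, 8), multiplied, then the 1 is removed.
print((v @ M).shape)  # torch.Size([16])

# 2-D @ 1-D is an ordinary matrix-vector product.
print((M.T @ v).shape)  # torch.Size([16])

# 1-D @ 3-D: v is treated as (1, 8), broadcast over the batch, then squeezed.
print((v @ T).shape)  # torch.Size([10, 16])
```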

Best regards

Thomas