I don’t understand why this code snippet gives the output it gives. Can you explain the math behind this or post a link to someplace explaining higher dimensional matrix multiplication? Thanks!
A = torch.randn(1, 64, 1152, 1, 8)
B = torch.randn(10, 1, 1152, 8, 16)
C = A @ B
output: torch.Size([10, 64, 1152, 1, 16])
The matrix multiplications are done between the last two dimensions (1×8 @ 8×16 → 1×16). The remaining first three dimensions are broadcast against each other and act as ‘batch’ dimensions, so you get 10×64×1152 independent matrix multiplications.
Thank you that’s very helpful!
@tom Is the @ operator documented somewhere? What if A and B have different numbers of dimensions? What if either A or B only has one dimension? I’m trying to understand how this operator can be used to calculate the weighted average according to a formula given in this answer.
It’s a bit tricky and the link is not obvious, but in fact the A @ B operator internally maps to torch.matmul(A, B), so you can use the linked documentation for the latter.
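The torch.matmul documentation covers the cases you asked about. A short sketch of the special cases for 1-D operands and for operands with different numbers of dimensions:

```python
import torch

v = torch.randn(8)
M = torch.randn(8, 16)

# 1-D @ 2-D: v is treated as a row vector, and the
# prepended dimension is removed again, so the result is 1-D.
print((v @ M).shape)  # torch.Size([16])

# 2-D @ 1-D: ordinary matrix-vector product, result is 1-D.
w = torch.randn(16)
print((M @ w).shape)  # torch.Size([8])

# 1-D @ 1-D: dot product, result is a 0-D (scalar) tensor.
u = torch.randn(8)
print((u @ v).shape)  # torch.Size([])

# Different numbers of dimensions: the batch dimensions
# missing from one operand are broadcast.
A = torch.randn(5, 2, 3)
B = torch.randn(3, 4)
print((A @ B).shape)  # torch.Size([5, 2, 4])
```

For a weighted average, v @ M with a 1-D weight vector v summing to 1 gives exactly the weighted combination of the rows of M.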