I don’t understand why this code snippet gives the output it gives. Can you explain the math behind this or post a link to someplace explaining higher dimensional matrix multiplication? Thanks!

import torch

A = torch.randn(1, 64, 1152, 1, 8)
B = torch.randn(10, 1, 1152, 8, 16)
C = A @ B
print(C.size())  # torch.Size([10, 64, 1152, 1, 16])

The matrix multiplications are done between the last two dimensions (1×8 @ 8×16 --> 1×16). The three leading dimensions are treated as batch dimensions and broadcast against each other — (1, 64, 1152) with (10, 1, 1152) gives (10, 64, 1152) — so you get 10×64×1152 independent matrix multiplications, and C has shape (10, 64, 1152, 1, 16).
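A minimal sketch checking this, using the shapes from the question (the batch indices i, j, k are arbitrary picks for illustration):

```python
import torch

A = torch.randn(1, 64, 1152, 1, 8)
B = torch.randn(10, 1, 1152, 8, 16)
C = A @ B  # same as torch.matmul(A, B)

# Last two dims are multiplied (1x8 @ 8x16 -> 1x16); the leading
# dims (1, 64, 1152) and (10, 1, 1152) broadcast to (10, 64, 1152).
print(C.size())  # torch.Size([10, 64, 1152, 1, 16])

# Each batch entry is an ordinary 2-D matmul on the broadcast slices:
# A's first dim (size 1) is repeated over i, B's second (size 1) over j.
i, j, k = 3, 5, 7
assert torch.allclose(C[i, j, k], A[0, j, k] @ B[i, 0, k], atol=1e-5)
```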

@tom Is the @ operator documented somewhere? What if A and B have different numbers of dimensions? What if A or B has only one dimension? I’m trying to understand how this operator can be used to calculate a weighted average according to the formula given in this answer.

It’s a bit tricky and the link is not obvious, but A @ B internally maps to torch.matmul(A, B), so the documentation for the latter applies. It also spells out the rules for 1-D arguments and for operands with different numbers of dimensions.
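A short sketch of the 1-D cases described in torch.matmul’s documentation, plus a weighted average written as a matmul (the names weights/values are hypothetical, not from the linked answer):

```python
import torch

v = torch.randn(8)       # 1-D tensor
w = torch.randn(8)
M = torch.randn(8, 16)

# 1-D @ 1-D -> dot product, a 0-D (scalar) tensor
assert (v @ w).dim() == 0

# 1-D @ 2-D: v is treated as a row vector (a leading 1 is prepended
# to its shape, then removed from the result), giving shape (16,)
assert (v @ M).shape == (16,)

# 2-D @ 1-D: the vector is treated as a column; result is 1-D
assert (M.t() @ v).shape == (16,)

# Weighted average as a matmul: weights sum to 1, so weights @ values
# is a convex combination of the 8 rows of values.
weights = torch.softmax(torch.randn(8), dim=0)
values = torch.randn(8, 16)
avg = weights @ values   # shape (16,)
assert torch.allclose(avg, (weights.unsqueeze(1) * values).sum(0), atol=1e-5)
```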