I don’t understand why this code snippet gives the output it gives. Can you explain the math behind this or post a link to someplace explaining higher dimensional matrix multiplication? Thanks!

import torch

A = torch.randn(1, 64, 1152, 1, 8)
B = torch.randn(10, 1, 1152, 8, 16)
C = A @ B
print(C.size())

The matrix multiplications are done between the last two dimensions (1×8 @ 8×16 --> 1×16). The remaining leading dimensions are broadcast against each other, (1, 64, 1152) and (10, 1, 1152) --> (10, 64, 1152), and act as ‘batch’ dimensions, so you get 10×64×1152 independent matrix multiplications and C has size (10, 64, 1152, 1, 16).
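To make the broadcasting concrete, here is a small sketch that checks the resulting shape and verifies one batch entry against a manually expanded computation (the indices 3 and 5 are arbitrary picks for illustration):

```python
import torch

# Batch dims broadcast like ordinary tensors: (1, 64, 1152) and
# (10, 1, 1152) combine to (10, 64, 1152); the last two dims do
# an ordinary (1x8) @ (8x16) -> (1x16) matrix multiplication.
A = torch.randn(1, 64, 1152, 1, 8)
B = torch.randn(10, 1, 1152, 8, 16)
C = A @ B
assert C.size() == torch.Size([10, 64, 1152, 1, 16])

# Same thing with the batch dims expanded explicitly and one
# batch entry picked out by hand:
manual = A.expand(10, 64, 1152, 1, 8)[3, 5] @ B.expand(10, 64, 1152, 8, 16)[3, 5]
assert torch.allclose(C[3, 5], manual)
```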

@tom Is the @ operator documented somewhere? What if A and B have different numbers of dimensions? What if either A or B has only one dimension? I’m trying to understand how this operator can be used to calculate the weighted average according to a formula given in this answer.

It’s a bit tricky and the link is not obvious, but the A @ B operator internally maps to torch.matmul(A, B), so you can use the linked documentation for the latter.
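The 1-D cases from the torch.matmul documentation can be summarized in a short sketch: a 1-D first argument gets a dimension of size 1 prepended (and removed again afterwards), while a 1-D second argument gives a matrix-vector product.

```python
import torch

v = torch.randn(8)             # 1-D vector
M = torch.randn(8, 16)         # 2-D matrix
batch = torch.randn(10, 8, 16) # batched matrices

# 1-D @ 1-D: dot product, returns a 0-dim tensor
assert (v @ v).dim() == 0

# 1-D @ 2-D: v is treated as a 1x8 row vector, the prepended dim is removed
assert (v @ M).shape == torch.Size([16])

# 2-D @ 1-D: matrix-vector product
assert (M.t() @ v).shape == torch.Size([16])

# 1-D @ 3-D: v is broadcast as a row vector over the batch dimension
assert (v @ batch).shape == torch.Size([10, 16])
```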

To my mind, the trouble with maths lectures is that, of all the possible explanations of a given thing, the subset that resonates with a student is very individual, and whether the explanation presented in class happens to be one of the resonating ones for you is somewhat a matter of chance.

That said, I like to identify n×m matrices with linear functions mapping R^m to R^n via multiplication on the left. If you then spell out the composition of two such functions and check that it is represented by the product matrix, this tells you how the dimensions have to match in matrix multiplication.
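The composition view above can be checked numerically. The sketch below (with arbitrarily chosen sizes 3, 4, 2) applies two maps one after the other and compares that with applying the product matrix once; the inner dimension of 4 has to match for B @ A to be defined, which is exactly the dimension rule:

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 3)  # linear map R^3 -> R^4
B = torch.randn(2, 4)  # linear map R^4 -> R^2
x = torch.randn(3)

# Composing the maps: apply A first, then B ...
composed = B @ (A @ x)
# ... equals applying the product matrix, a 2x3 map R^3 -> R^2
product = (B @ A) @ x
assert torch.allclose(composed, product, atol=1e-6)
assert (B @ A).shape == torch.Size([2, 3])
```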