Question about 5-d tensor multiplication by @ operation

Hello guys,

I came across an interesting operation and am curious about how it works.

import torch

aa = torch.rand([1, 100, 1152, 1, 8])
bb = torch.rand([10, 1, 1152, 8, 16])
cc = aa @ bb
print(cc.size())

output is

torch.Size([10, 100, 1152, 1, 16])

My question is: how does @ (the matrix multiplication operator) operate on these two tensors? It seems to do two sub-matrix multiplications: one of a 10x1 by a 1x100, and another of a 1x8 by an 8x16.

Could anyone explain this in more detail?

Thanks in advance.

Hi,

It does the same thing as torch.matmul(), which is described in the PyTorch documentation.
Here it performs a matrix-matrix multiplication on the last 2 dimensions of each tensor (1x8 * 8x16) and treats all other dimensions as batch dimensions. If a batch dimension has size 1 in one tensor and a different size in the other, the size-1 dimension is expanded (broadcast) to match.
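As a rough sketch of that shape rule, here is a hypothetical pure-Python helper, matmul_result_shape, that only handles operands with at least 2 dimensions (the 1-D cases of torch.matmul have extra rules that are not covered here). It treats the last two dimensions as the matrices and broadcasts the remaining ones from right to left:

```python
import torch

def matmul_result_shape(a_shape, b_shape):
    """Hypothetical helper: output shape of a @ b for operands with ndim >= 2."""
    *a_batch, n, m = a_shape
    *b_batch, m2, p = b_shape
    if m != m2:
        raise ValueError(f"inner matrix dims differ: {m} vs {m2}")
    batch = []
    # Walk the batch dims from the right; a missing dim counts as size 1.
    for i in range(1, max(len(a_batch), len(b_batch)) + 1):
        x = a_batch[-i] if i <= len(a_batch) else 1
        y = b_batch[-i] if i <= len(b_batch) else 1
        if x != y and x != 1 and y != 1:
            raise ValueError(f"batch dims {x} and {y} cannot be broadcast")
        batch.append(y if x == 1 else x)  # a size-1 dim is expanded to the other size
    return tuple(reversed(batch)) + (n, p)

print(matmul_result_shape((1, 100, 1152, 1, 8), (10, 1, 1152, 8, 16)))
# (10, 100, 1152, 1, 16)

# Cross-check against PyTorch on small tensors with different ranks
a, b = torch.rand(2, 3, 4), torch.rand(5, 1, 4, 6)
assert matmul_result_shape(a.shape, b.shape) == tuple((a @ b).shape)  # (5, 2, 3, 6)
```

The matrix part (1x8 times 8x16 giving 1x16) never interacts with the batch part, which is why the 10 and the 100 simply end up side by side in the output shape.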
So, in your example, aa @ bb is equivalent to:

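# Expand the batch dimensions (size-1 dims are broadcast)
aa = aa.expand(10, 100, 1152, 1, 8)
bb = bb.expand(10, 100, 1152, 8, 16)
# Collapse all batch dimensions into a single one.
# reshape (not view) is needed: expanded tensors are not contiguous, so view fails here
aa = aa.reshape(-1, 1, 8)
bb = bb.reshape(-1, 8, 16)
# Do the batched mm operation
cc = torch.bmm(aa, bb)
# Restore the batch dimensions in the output (bmm's result is contiguous, so view works)
cc = cc.view(10, 100, 1152, 1, 16)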
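To check that the expand/reshape/bmm recipe really matches @, a sketch like the following compares the two on smaller tensors (the shapes are arbitrary stand-ins for the ones in the question, chosen so the expanded copies stay small). matmul_via_bmm is a hypothetical helper that generalises the steps above and assumes a PyTorch version that provides torch.broadcast_shapes:

```python
import torch

def matmul_via_bmm(x, y):
    # Hypothetical helper: same steps as above, for any x, y with ndim >= 2
    batch = torch.broadcast_shapes(x.shape[:-2], y.shape[:-2])
    n, m = x.shape[-2:]
    p = y.shape[-1]
    x = x.expand(*batch, n, m).reshape(-1, n, m)  # reshape copies when needed
    y = y.expand(*batch, m, p).reshape(-1, m, p)
    return torch.bmm(x, y).view(*batch, n, p)

aa = torch.rand(1, 4, 6, 1, 8)
bb = torch.rand(3, 1, 6, 8, 16)
out = matmul_via_bmm(aa, bb)
print(out.shape)                     # torch.Size([3, 4, 6, 1, 16])
print(torch.allclose(out, aa @ bb))  # True
```

Note that the explicit version materialises the expanded tensors as real copies, so with the original sizes it uses far more memory than writing aa @ bb directly.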

I see. The broadcasting mechanism is applied implicitly here. Thanks!
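One way to see what the implicit broadcasting means for each entry is the sketch below (again with smaller, arbitrary stand-in shapes). It checks that every output matrix cc[i, j, k] is a single ordinary (1x8) @ (8x16) product that reuses aa's only slice along dim 0 and bb's only slice along dim 1, so the sizes 10 and 100 are never multiplied with each other:

```python
import torch

A, B, C = 3, 4, 5  # stand-ins for 10, 100, 1152
aa = torch.rand(1, B, C, 1, 8)
bb = torch.rand(A, 1, C, 8, 16)
cc = aa @ bb
print(cc.shape)  # torch.Size([3, 4, 5, 1, 16])

for i in range(A):
    for j in range(B):
        for k in range(C):
            # aa's size-1 dim 0 is reused for every i, bb's size-1 dim 1 for every j
            expected = aa[0, j, k] @ bb[i, 0, k]  # (1, 8) @ (8, 16) -> (1, 16)
            assert torch.allclose(cc[i, j, k], expected)
print("every cc[i, j, k] is aa[0, j, k] @ bb[i, 0, k]")
```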