I don't understand why this code snippet gives the output it does. Can you explain the math behind it, or post a link to a resource that explains higher-dimensional matrix multiplication? Thanks!

```
import torch

A = torch.randn(1, 64, 1152, 1, 8)
B = torch.randn(10, 1, 1152, 8, 16)
C = A @ B  # same as torch.matmul(A, B)
print(C.size())
```

Output: `torch.Size([10, 64, 1152, 1, 16])`
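
For what it's worth, here is a minimal sketch of how the shape above can be reproduced by hand, assuming `@` follows the documented `torch.matmul` behavior: the last two dimensions are treated as ordinary matrices, `(1, 8) @ (8, 16) -> (1, 16)`, and all leading "batch" dimensions are broadcast against each other, `(1, 64, 1152)` with `(10, 1, 1152)` giving `(10, 64, 1152)`. The names `batch_shape`, `matrix_shape`, and the indices `i, j, k` are only illustrative.

```python
import torch

A = torch.randn(1, 64, 1152, 1, 8)
B = torch.randn(10, 1, 1152, 8, 16)
C = A @ B

# Leading (batch) dims broadcast like elementwise ops: size-1 dims stretch to match.
batch_shape = torch.broadcast_shapes(A.shape[:-2], B.shape[:-2])  # (10, 64, 1152)

# The last two dims follow the usual matrix rule (n, m) @ (m, p) -> (n, p).
matrix_shape = (A.shape[-2], B.shape[-1])  # (1, 16)

expected_shape = tuple(batch_shape) + matrix_shape
print(expected_shape == tuple(C.shape))

# Each batch entry is an independent 2-D product. Where a batch dim has size 1,
# the same matrix is reused, so index 0 is used along that dim.
i, j, k = 3, 5, 7  # arbitrary batch position, chosen for illustration
single = A[0, j, k] @ B[i, 0, k]  # (1, 8) @ (8, 16) -> (1, 16)
print(torch.allclose(C[i, j, k], single, atol=1e-4))
```

In other words, `C[i, j, k] = A[0, j, k] @ B[i, 0, k]` for every valid `i, j, k`, which is why the result has 10 × 64 × 1152 small `(1, 16)` matrices.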