Understanding batch multiplication using torch.matmul

The bullet point on batched matrix multiplication in the torch.matmul documentation contains the following statement:
"The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be broadcastable). For example, if tensor1 is a (j×1×n×m) tensor and tensor2 is a (k×m×p) tensor, out will be an (j×k×n×p) tensor."
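To make the quoted rule concrete, here is a minimal pure-Python sketch (no PyTorch needed) of how two batch shapes are combined. It assumes standard NumPy-style broadcasting, which is what the documentation's "broadcastable" refers to: the shapes are aligned from the right, a missing leading dimension counts as 1, and each pair of sizes must be equal or contain a 1. The helper name `broadcast_batch_shape` and the concrete sizes j=2, k=3 are illustrative choices, not part of PyTorch.

```python
from itertools import zip_longest


def broadcast_batch_shape(batch_a, batch_b):
    """Broadcast two batch shapes (tuples of ints) NumPy/PyTorch style.

    Hypothetical helper for illustration; it is not a PyTorch API.
    """
    result = []
    # Walk both shapes from the right; a missing dimension counts as size 1.
    for a, b in zip_longest(reversed(batch_a), reversed(batch_b), fillvalue=1):
        if a == b or b == 1:
            result.append(a)
        elif a == 1:
            result.append(b)
        else:
            raise ValueError(f"batch sizes {a} and {b} are not broadcastable")
    return tuple(reversed(result))


# Documentation example with j=2, k=3:
# tensor1 has batch dims (j, 1), tensor2 has batch dims (k,).
print(broadcast_batch_shape((2, 1), (3,)))  # (2, 3)
```

Appending the matrix part, (n×m) @ (m×p) → (n×p), to the broadcast batch shape (j, k) gives the documented (j×k×n×p) output.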

From this statement, it is not clear to me how the non-matrix dimensions are identified. For example, suppose I have two tensors A and B of sizes (1000, 500, 100, 10) and (500, 10, 50), respectively. What will the dimensions of the matrices being multiplied be, and how many multiplications will be done in the batch?



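The matrix multiplication is always done using the last two dimensions; all the dimensions before them are treated as batch dimensions. (This holds when both tensors have at least two dimensions; 1-D inputs are handled as special cases.)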
In your case, the matrices being multiplied are of size 100x10 and 10x50. The batch dimensions are 1000x500 and 500; aligned from the right, they broadcast to 1000x500, so 1000 * 500 = 500,000 matrix multiplications are done in the batch. The final output will thus be of size 1000x500x100x50.
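The shape arithmetic for the example above can be checked with a short pure-Python sketch. It assumes both inputs have at least two dimensions (torch.matmul treats 1-D inputs differently), and the helper name `matmul_output_shape` is hypothetical, not a PyTorch function. It splits each shape into batch and matrix parts, broadcasts the batch parts, and counts how many individual matrix products the batched call performs.

```python
import math
from itertools import zip_longest


def matmul_output_shape(shape_a, shape_b):
    """Return (output_shape, number_of_matrix_products) for a batched matmul.

    Hypothetical helper; assumes len(shape_a) >= 2 and len(shape_b) >= 2.
    """
    batch_a, (n, m) = shape_a[:-2], shape_a[-2:]
    batch_b, (m2, p) = shape_b[:-2], shape_b[-2:]
    if m != m2:
        raise ValueError(f"inner dimensions differ: {m} vs {m2}")

    batch = []
    # Align batch dimensions from the right; missing ones count as size 1.
    for a, b in zip_longest(reversed(batch_a), reversed(batch_b), fillvalue=1):
        if a != b and 1 not in (a, b):
            raise ValueError(f"batch sizes {a} and {b} are not broadcastable")
        batch.append(b if a == 1 else a)
    batch = tuple(reversed(batch))

    # One (n x m) @ (m x p) product per element of the broadcast batch.
    return batch + (n, p), math.prod(batch)


shape, count = matmul_output_shape((1000, 500, 100, 10), (500, 10, 50))
print(shape)  # (1000, 500, 100, 50)
print(count)  # 500000
```

With PyTorch installed, the batch part alone can also be checked with torch.broadcast_shapes((1000, 500), (500,)), which returns torch.Size([1000, 500]).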


Thanks, that clarifies things for me.