I need to perform a batched matrix multiplication between two tensors A and B, where DIM(A) = [**m1, n, o**] and DIM(B) = [**m2, n, o**], such that the output

C = A x B

has the same shape as A. The batch dimensions of A and B (**m1** and **m2**) differ, so the product C = A x B is computed slice-wise, for example:

```
C[0:2] = A[0:2, :, :] x B[0, :, :]^T
C[2:6] = A[2:6, :, :] x B[1, :, :]^T
C[6:8] = A[6:8, :, :] x B[2, :, :]^T
```

The map from indices of the **m1** dim to indices of the **m2** dim is known upfront: each index in the **m2** dim maps to one or more indices in the **m1** dim. This is essentially running several distinct matmul ops of different sizes in parallel. The naive way to implement this in PyTorch would be a for-loop (sketched below), but that would not execute these matmuls in parallel. What is the best way to implement this in PyTorch to get the best GPU performance, or is a custom CUDA kernel a must?
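For reference, here is a minimal sketch of the naive loop, assuming illustrative sizes and a hypothetical `batch_map` tensor that encodes the index map from the slicing example above:

```python
import torch

# Illustrative sizes matching the example: m1 = 8, m2 = 3 (n = o assumed
# here so that each product has the same shape as the corresponding A slice).
m1, m2, n, o = 8, 3, 4, 4
A = torch.randn(m1, n, o)
B = torch.randn(m2, n, o)

# batch_map[i] is the index into B's batch dim used for A[i]
# (hypothetical name; encodes 0:2 -> 0, 2:6 -> 1, 6:8 -> 2).
batch_map = torch.tensor([0, 0, 1, 1, 1, 1, 2, 2])

# Naive loop: one small matmul per A slice, launched one after another.
C = torch.empty(m1, n, n)  # each A[i] @ B[j].T is [n, n]
for i in range(m1):
    C[i] = A[i] @ B[batch_map[i]].T
```

Note that since `B[batch_map]` materializes the mapped slices into an [m1, n, o] tensor, the loop above is equivalent to a single `torch.bmm(A, B[batch_map].transpose(1, 2))`, at the cost of duplicating rows of B in memory.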