Performing torch.bmm with distinct batch sizes

I need to perform batched matrix multiplication between two tensors A and B, where A has shape [m1, n, o] and B has shape [m2, n, o], and the output C

C = A x B

has the same batch size m1 as A. The batch dimensions of A and B differ (m1 != m2), so the product is computed blockwise, for example:

```python
C[0:2] = A[0:2] @ B[0].T
C[2:6] = A[2:6] @ B[1].T
C[6:8] = A[6:8] @ B[2].T
```
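To make the shapes concrete, here is a minimal setup matching the example above (the sizes n, o and the block layout are placeholders I picked for illustration):

```python
import torch

n, o = 4, 4                       # placeholder feature sizes
counts = torch.tensor([2, 4, 2])  # block sizes: B[0]->A[0:2], B[1]->A[2:6], B[2]->A[6:8]
m2 = len(counts)                  # 3
m1 = int(counts.sum())            # 8

A = torch.randn(m1, n, o)
B = torch.randn(m2, n, o)
```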

The mapping from indices along the m1 dimension to indices along the m2 dimension is known upfront, and each index in the m2 dimension maps to one or more contiguous indices in the m1 dimension. This is essentially running several independent matmuls of different sizes in parallel. The naive way to implement this in PyTorch would be a for-loop, but that would not execute the matmuls in parallel. What is the best way to implement this in PyTorch to get the most out of the GPU, or is a custom CUDA kernel a must?
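For reference, this is the naive loop I mean, using the setup above; each block is a separate matmul (and kernel launch), so the blocks do not run concurrently:

```python
C = torch.empty(m1, n, n)
start = 0
for j, cnt in enumerate(counts.tolist()):
    # [cnt, n, o] @ [o, n] -> [cnt, n, n]; one matmul per B index
    C[start:start + cnt] = A[start:start + cnt] @ B[j].T
    start += cnt
```

One workaround I can think of is materializing B along the m1 dimension with repeat_interleave so a single bmm covers all blocks, at the cost of extra memory for the expanded B:

```python
B_exp = B.repeat_interleave(counts, dim=0)   # [m1, n, o], rows of B duplicated per block
C_vec = torch.bmm(A, B_exp.transpose(1, 2))  # [m1, n, n], matches C from the loop
```

This runs everything as one batched matmul, but duplicating B may be wasteful when the blocks are large, hence my question about whether there is a better approach.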

Can someone please help? Pinging the experts here, @ptrblck.