# Is there a way to simplify this Mahalanobis function for tensors? I would like to rewrite it in C++, but I cannot follow it

What is the simplest way to write the large function below, assuming my bL matrix is 42x42 and my bx matrix is 45x42?
The large (second) function below returns a 1D tensor of length 45, and I would like the same thing returned. I want it written in as few lines as possible, because I will always be using a bL of shape 42x42 and a bx of shape nx42, and I need to rewrite it in C++.

First, my attempt at condensing.

    tril = torch.rand(42, 42)
    value = torch.rand(45, 42)
    loc = torch.rand(42)
    diff = value - loc # [45, 42]

    m = torch.dot(diff, torch.matmul(torch.inverse(tril), diff))
    mahalanobis = torch.sqrt(m)


It fails with the error: mat1 and mat2 shapes cannot be multiplied (42x42 and 45x42).
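
Written out per row, the quantity I am after seems to be the following. This is only a sketch of the algebra: it assumes tril plays the role of the lower-triangular factor L with M = L L^T (as in the PyTorch docstring quoted further down), and d_i is my own notation for row i of diff.

$$
m_i = d_i^\top \mathbf{M}^{-1} d_i = d_i^\top (\mathbf{L}\mathbf{L}^\top)^{-1} d_i = \lVert \mathbf{L}^{-1} d_i \rVert_2^2, \qquad i = 1, \dots, 45
$$

Stacking the rows, the 42x42 matrix has to multiply the transpose of diff (42x45) rather than diff itself (45x42), which is where the shape error comes from. torch.dot also only accepts 1D tensors, so it cannot reduce a 2D result to one value per row.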

Below is the Mahalanobis distance function from the PyTorch source code:

    def _batch_mahalanobis(bL, bx):
        r"""
        Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
        for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`.

        Accepts batches for both bL and bx. They are not necessarily assumed to have the same batch
        shape, but bL one should be able to broadcasted to bx one.
        """
        n = bx.size(-1)
        bx_batch_shape = bx.shape[:-1]

        # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
        # we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tri.solve
        bx_batch_dims = len(bx_batch_shape)
        bL_batch_dims = bL.dim() - 2
        outer_batch_dims = bx_batch_dims - bL_batch_dims
        old_batch_dims = outer_batch_dims + bL_batch_dims
        new_batch_dims = outer_batch_dims + 2 * bL_batch_dims
        # Reshape bx with the shape (..., 1, i, j, 1, n)
        bx_new_shape = bx.shape[:outer_batch_dims]
        for (sL, sx) in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
            bx_new_shape += (sx // sL, sL)
        bx_new_shape += (n,)
        bx = bx.reshape(bx_new_shape)
        # Permute bx to make it have shape (..., 1, j, i, 1, n)
        permute_dims = (list(range(outer_batch_dims)) +
                        list(range(outer_batch_dims, new_batch_dims, 2)) +
                        list(range(outer_batch_dims + 1, new_batch_dims, 2)) +
                        [new_batch_dims])
        bx = bx.permute(permute_dims)

        flat_L = bL.reshape(-1, n, n)  # shape = b x n x n
        flat_x = bx.reshape(-1, flat_L.size(0), n)  # shape = c x b x n
        flat_x_swap = flat_x.permute(1, 2, 0)  # shape = b x n x c
        M_swap = torch.linalg.solve_triangular(flat_L, flat_x_swap, upper=False).pow(2).sum(-2)  # shape = b x c
        M = M_swap.t()  # shape = c x b

        # Now we revert the above reshape and permute operators.
        permuted_M = M.reshape(bx.shape[:-1])  # shape = (..., 1, j, i, 1)
        permute_inv_dims = list(range(outer_batch_dims))
        for i in range(bL_batch_dims):
            permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i]
        reshaped_M = permuted_M.permute(permute_inv_dims)  # shape = (..., 1, i, j, 1)
        return reshaped_M.reshape(bx_batch_shape)


I am not able to port all of these reshape and permute steps to another language, and would just like to extract the simplest equivalent code for matrices of a fixed size.
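
To make the target concrete, here is a minimal sketch of what the fixed-size case might reduce to once bL has no batch dimensions. It is written as plain Python loops (no PyTorch) so that it maps almost line for line onto C++. The function name `squared_mahalanobis_rows` is hypothetical, and the sketch assumes L is lower triangular with a nonzero diagonal and that the rows of X are already centered (that is, X is `value - loc`). Like `_batch_mahalanobis`, it returns the squared distance.

```python
def squared_mahalanobis_rows(L, X):
    """Squared Mahalanobis distance of each row of X, for M = L L^T.

    L: n x n nested list, lower triangular (entries above the diagonal are never read).
    X: m x n nested list of already-centered rows.
    Returns a list of m floats.
    """
    n = len(L)
    out = []
    for row in X:
        z = [0.0] * n
        sq = 0.0
        for i in range(n):
            # Forward substitution: solve L z = row one entry at a time.
            s = row[i]
            for j in range(i):
                s -= L[i][j] * z[j]
            z[i] = s / L[i][i]
            # Accumulate ||z||^2, which equals row^T (L L^T)^{-1} row.
            sq += z[i] * z[i]
        out.append(sq)
    return out


# Tiny 2x2 example in place of the 42x42 case.
L = [[2.0, 0.0],
     [1.0, 1.0]]
X = [[2.0, 3.0],
     [0.0, 0.0],
     [4.0, 1.0]]
print(squared_mahalanobis_rows(L, X))  # [5.0, 0.0, 5.0]
```

For a 42x42 L and an nx42 X this is the same triple loop with n = 42. Taking the square root of each entry gives the distance itself rather than its square.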