While profiling my code for bottlenecks, I found that the `_batch_mahalanobis` function from `torch.distributions.multivariate_normal` was quite slow.
I compared the latest version of the function (from the master branch) with a previous version, and the latest version is notably slower: about 4.3x slower on my machine.
Here is the output of my script, which compares the runtime and the results of the two implementations over 1000 runs:
[New] Time Taken: 0.677909s
[Old] Time Taken: 0.159045s
Average Relative Error: 7.209917385342379e-11
Comparison Script
import time
import torch
import numpy as np
def _batch_trtrs_lower(bb, bA):
"""
Applies `torch.trtrs` for batches of matrices. `bb` and `bA` should have
the same batch shape.
"""
flat_b = bb.reshape((-1,) + bb.shape[-2:])
flat_A = bA.reshape((-1,) + bA.shape[-2:])
flat_X = torch.stack([torch.trtrs(b, A, upper=False)[0] for b, A in zip(flat_b, flat_A)])
return flat_X.reshape(bb.shape)
def _batch_mahalanobis_old(bL, bx):
    r"""
    Squared Mahalanobis distance :math:`\mathbf{x}^\top(\mathbf{L}\mathbf{L}^\top)^{-1}\mathbf{x}`,
    batched over both arguments (earlier implementation, kept for comparison).

    The batch shape of `bL` must be broadcastable to the trailing batch dimensions
    of `bx`. `bL` is expanded to those trailing dimensions before solving, so each
    broadcast copy of a factor is solved separately.

    Args:
        bL: lower-triangular factors, shape ``(*L_batch, n, n)``.
        bx: vectors, shape ``(..., n)``.

    Returns:
        Tensor of shape ``bx.shape[:-1]``.
    """
    dim = bx.shape[-1]
    # Take bx's last bL.dim() - 1 dims (its trailing batch dims followed by n),
    # append another n, and broadcast bL to that shape.
    first_kept = bx.dim() - bL.dim() + 1
    target_shape = bx.shape[first_kept:] + (dim,)
    factors = bL.expand(target_shape).reshape(-1, dim, dim)  # b x n x n
    num_factors = factors.shape[0]
    # The columns of each right-hand side are all vectors paired with the same factor.
    rhs = bx.reshape(-1, num_factors, dim).permute(1, 2, 0)  # b x n x c
    # ||L^{-1} x||^2 equals x^T (L L^T)^{-1} x: solve, square, sum over n.
    solved = _batch_trtrs_lower(rhs, factors)
    sq_dist = solved.pow(2).sum(dim=-2)  # b x c
    return sq_dist.transpose(0, 1).reshape(bx.shape[:-1])
def _batch_mv(bmat, bvec):
r"""
Performs a batched matrix-vector product, with compatible but different batch shapes.
This function takes as input `bmat`, containing :math:`n \times n` matrices, and
`bvec`, containing length :math:`n` vectors.
Both `bmat` and `bvec` may have any number of leading dimensions, which correspond
to a batch shape. They are not necessarily assumed to have the same batch shape,
just ones which can be broadcasted.
"""
return torch.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1)
def _batch_mahalanobis(bL, bx):
    r"""
    Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
    for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`.
    Accepts batches for both bL and bx. They are not necessarily assumed to have the same batch
    shape, but the batch shape of `bL` must be broadcastable to that of `bx`.

    Unlike `_batch_mahalanobis_old`, `bL` is never expanded to `bx`'s batch shape. Instead,
    `bx` is reshaped and permuted so that each factor in `bL` is solved once against all
    vectors that share it. The result is then permuted and reshaped back.

    Args:
        bL: lower-triangular factors, shape ``(*L_batch, n, n)``.
        bx: vectors, shape ``(*outer, *x_batch, n)``, where ``x_batch`` has as many dims as
            ``L_batch``. Each ``L_batch`` dim is presumably 1 or equal to the matching
            ``x_batch`` dim (the ``sx // sL`` split below assumes it divides evenly).

    Returns:
        Tensor of shape ``bx.shape[:-1]``.

    NOTE(review): when `bL` has no batch dims, the reshape/permute steps below reduce to
    identity operations and produce the same ``flat_L`` / ``flat_x_swap`` shapes as
    `_batch_mahalanobis_old`. Any runtime gap in that case presumably comes from the
    triangular-solve call itself, which here is one call on a 3-D (batched) input.
    Confirm by profiling.
    """
    n = bx.size(-1)
    bx_batch_shape = bx.shape[:-1]

    # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
    # we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tri.solve
    bx_batch_dims = len(bx_batch_shape)
    bL_batch_dims = bL.dim() - 2
    # Leading batch dims of bx that bL does not have at all.
    outer_batch_dims = bx_batch_dims - bL_batch_dims
    # Equals bx's original batch rank; used as an offset when undoing the permute.
    old_batch_dims = outer_batch_dims + bL_batch_dims
    # Each of bL's batch dims is split into two dims of bx below, hence the factor 2.
    new_batch_dims = outer_batch_dims + 2 * bL_batch_dims
    # Reshape bx with the shape (..., 1, i, j, 1, n)
    bx_new_shape = bx.shape[:outer_batch_dims]
    for (sL, sx) in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
        # Split bx's dim into (sx // sL, sL): (1, sx) when sL == sx, and (sx, 1) when
        # sL == 1, i.e. when bL is broadcast along this dim.
        bx_new_shape += (sx // sL, sL)
    bx_new_shape += (n,)
    bx = bx.reshape(bx_new_shape)
    # Permute bx to make it have shape (..., 1, j, i, 1, n):
    # the (sx // sL) dims at even offsets go first and the sL dims at odd offsets go last,
    # so bx's trailing batch dims line up with bL's batch shape.
    permute_dims = (list(range(outer_batch_dims)) +
                    list(range(outer_batch_dims, new_batch_dims, 2)) +
                    list(range(outer_batch_dims + 1, new_batch_dims, 2)) +
                    [new_batch_dims])
    bx = bx.permute(permute_dims)

    flat_L = bL.reshape(-1, n, n)  # shape = b x n x n
    flat_x = bx.reshape(-1, flat_L.size(0), n)  # shape = c x b x n
    flat_x_swap = flat_x.permute(1, 2, 0)  # shape = b x n x c
    # ||L^{-1} x||^2 == x^T (L L^T)^{-1} x: solve L y = x, then sum y^2 over the n axis.
    M_swap = torch.triangular_solve(flat_x_swap, flat_L, upper=False)[0].pow(2).sum(-2)  # shape = b x c
    M = M_swap.t()  # shape = c x b

    # Now we revert the above reshape and permute operators.
    permuted_M = M.reshape(bx.shape[:-1])  # shape = (..., 1, j, i, 1)
    permute_inv_dims = list(range(outer_batch_dims))
    for i in range(bL_batch_dims):
        # Re-interleave each (sx // sL, sL) pair back into adjacent positions.
        permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i]
    reshaped_M = permuted_M.permute(permute_inv_dims)  # shape = (..., 1, i, j, 1)
    return reshaped_M.reshape(bx_batch_shape)
# Benchmark: time both implementations on identical inputs and report the mean
# relative difference between their outputs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _sync():
    # CUDA work is launched asynchronously, so wait for it before reading the
    # clock; on CPU there is nothing to wait for.
    if device.type == "cuda":
        torch.cuda.synchronize()


input_1 = torch.from_numpy(np.linalg.cholesky(np.diag(np.random.rand(2).astype(np.float32)))).to(device)
input_2 = torch.from_numpy(np.random.rand(13440, 2).astype(np.float32)).to(device)
runs = 1000
warmup_runs = 10
# Run both implementations a few times before timing, so that any one-time
# first-call costs (e.g. library/handle initialisation) are not charged to
# whichever implementation happens to run first in the timed loop.
for _ in range(warmup_runs):
    _batch_mahalanobis(input_1, input_2)
    _batch_mahalanobis_old(input_1, input_2)
_sync()
total_time_new, total_time_old = 0.0, 0.0
relative_error_cum = 0.0
for _ in range(runs):
    _sync()
    st = time.perf_counter()
    m_new = _batch_mahalanobis(input_1, input_2)
    _sync()
    total_time_new += time.perf_counter() - st
    st = time.perf_counter()
    m_old = _batch_mahalanobis_old(input_1, input_2)
    _sync()
    total_time_old += time.perf_counter() - st
    # Accumulate (previously '=' overwrote the value, so dividing by `runs` below
    # reported last_error / runs instead of a mean). .item() keeps a Python float.
    relative_error_cum += (torch.norm(m_new - m_old) / torch.norm(m_new)).item()
print(f'[New] Time Taken: {total_time_new:.6f}s')
print(f'[Old] Time Taken: {total_time_old:.6f}s')
print(f'Average Relative Error: {relative_error_cum/runs}')
Can anyone offer insight into why the newer version is slower?
The results of the two implementations differ only negligibly (see the relative error above), so right now I can't see any advantage to using the newer version.