Is my batched inversion for 3-D tensors implemented correctly, and is there any way to optimize it?

Issue description

Official support for batched matrix inversion has not been released yet, so I have written a snippet for the operation.

Is this LU-decomposition-based implementation correct? And is there any way to optimize the code further? I hope it can help anyone who needs it.

Code

Below is my implementation; one can also refer to the gist.

def inv(A, eps=1e-10):

    # Expect a batch of square matrices: (batch, n, n).
    assert A.dim() == 3 and A.shape[1] == A.shape[2]
    n = A.shape[1]
    U = A.clone().data             # detached working copy, reduced to U in place
    L = A.new_zeros(A.shape).data  # unit lower-triangular factor
    L[:, range(n), range(n)] = 1
    I = L.clone()                  # batch of identity matrices

    # Forward elimination on the augmented system [A | I]:
    # A = LU, so [A | I] -> [U | L^{-1}].
    # Note: no pivoting is performed; eps only guards against division by
    # zero, so accuracy degrades for matrices that need row exchanges.
    L_inv = I  # I is overwritten in place and becomes L^{-1}
    for i in range(n-1):
        # Multipliers of the rows below pivot i.
        L[:, i+1:, i:i+1] = U[:, i+1:, i:i+1] / (U[:, i:i+1, i:i+1] + eps)
        # Apply the same row operations to the augmented block and to U.
        L_inv[:, i+1:, :] = L_inv[:, i+1:, :] - L[:, i+1:, i:i+1].matmul(L_inv[:, i:i+1, :])
        U[:, i+1:, :] = U[:, i+1:, :] - L[:, i+1:, i:i+1].matmul(U[:, i:i+1, :])

    # Back substitution: [U | L^{-1}] -> [I | U^{-1}L^{-1}] = [I | (LU)^{-1}]
    A_inv = L_inv  # L_inv is overwritten in place and becomes A^{-1}
    for i in range(n-1, -1, -1):
        # Normalize row i by its pivot (A_inv first, while the pivot is intact).
        A_inv[:, i:i+1, :] = A_inv[:, i:i+1, :] / (U[:, i:i+1, i:i+1] + eps)
        U[:, i:i+1, :] = U[:, i:i+1, :] / (U[:, i:i+1, i:i+1] + eps)

        if i > 0:
            # Eliminate the entries above pivot i.
            A_inv[:, :i, :] = A_inv[:, :i, :] - U[:, :i, i:i+1].matmul(A_inv[:, i:i+1, :])
            U[:, :i, :] = U[:, :i, :] - U[:, :i, i:i+1].matmul(U[:, i:i+1, :])

    # Everything above ran on detached tensors, so reattach the graph here:
    # the forward value is still A_inv (the detached term cancels), while
    # backward sees d(A^{-1}) = -A^{-1} dA A^{-1}, the derivative of the inverse.
    A_inv_grad = - A_inv.matmul(A).matmul(A_inv)
    return A_inv + A_inv_grad - A_inv_grad.data
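
For anyone trying this out, here is a minimal sanity check against torch.inverse applied matrix by matrix (a sketch; the batch and matrix sizes are arbitrary, and accuracy degrades for inputs that would need pivoting):

import torch

A = torch.randn(4, 5, 5, requires_grad=True)
A_inv = inv(A)

# Forward values should match per-matrix torch.inverse for well-conditioned inputs.
ref = torch.stack([torch.inverse(A[k].data) for k in range(A.shape[0])])
print((A_inv.data - ref).abs().max())

# Gradients flow through the reattachment trick in the last two lines of inv.
A_inv.sum().backward()
print(A.grad.shape)  # torch.Size([4, 5, 5])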

PyTorch already has a few BLAS/LAPACK-style operations implemented. torch.btrifact, for instance, performs a batched LU factorization and could serve as a faster (and probably more numerically stable) alternative.
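
For example, a batched inverse built on those routines might look like the sketch below (untested; it assumes a PyTorch version that ships torch.btrifact and torch.btrisolve; newer releases replace these with torch.lu / torch.lu_solve and give torch.inverse native batch support):

import torch

def inv_lu(A):
    # Batched LU factorization with partial pivoting.
    A_LU, pivots = torch.btrifact(A)
    b, n = A.shape[0], A.shape[1]
    eye = torch.eye(n, dtype=A.dtype, device=A.device)
    # btrisolve solves A x = rhs for a batch of vectors, so the inverse is
    # assembled column by column from the canonical basis vectors.
    cols = [torch.btrisolve(eye[j].expand(b, n).contiguous(), A_LU, pivots)
            for j in range(n)]
    return torch.stack(cols, dim=2)  # solutions become the columns: (b, n, n)

Unlike the elimination snippet above, the LAPACK-backed factorization pivots, so it does not need the eps fudge factor.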