Issue description
The official support for batched inversion has not been released, so I have coded a snippet for the operation.
Is this implementation using LU decomposition right? And any solution for optimizing this code? Hope it could help anyone who needs this.
Code
Below is my implementation, and one can also refer to the gist.
def inv(A, eps = 1e-10):
assert len(A.shape) == 3 and \
A.shape[1] == A.shape[2]
n = A.shape[1]
U = A.clone().data
L = A.new_zeros(A.shape).data
L[:, range(n), range(n)] = 1
I = L.clone()
# A = LU
# [A I] = [LU I] -> [U L^{-1}]
L_inv = I
for i in range(n-1):
L[:, i+1:, i:i+1] = U[:, i+1:, i:i+1] / (U[:, i:i+1, i:i+1] + eps)
L_inv[:, i+1:, :] = L_inv[:, i+1:, :] - L[:, i+1:, i:i+1].matmul(L_inv[:, i:i+1, :])
U[:, i+1:, :] = U[:, i+1:, :] - L[:, i+1:, i:i+1].matmul(U[:, i:i+1, :])
# [U L^{-1}] -> [I U^{-1}L^{-1}] = [I (LU)^{-1}]
A_inv = L_inv
for i in range(n-1, -1, -1):
A_inv[:, i:i+1, :] = A_inv[:, i:i+1, :] / (U[:, i:i+1, i:i+1] + eps)
U[:, i:i+1, :] = U[:, i:i+1, :] / (U[:, i:i+1, i:i+1] + eps)
if i > 0:
A_inv[:, :i, :] = A_inv[:, :i, :] - U[:, :i, i:i+1].matmul(A_inv[:, i:i+1, :])
U[:, :i, :] = U[:, :i, :] - U[:, :i, i:i+1].matmul(U[:, i:i+1, :])
A_inv_grad = - A_inv.matmul(A).matmul(A_inv)
return A_inv + A_inv_grad - A_inv_grad.data