Hi there,
I am trying to write a Cholesky rank-1 update function, together with its corresponding backpropagation, in PyTorch, since this is not natively supported. However, I find that it is slower than simply recomputing the factorization with torch.linalg.cholesky. The function I wrote is shown below.
@torch.jit.script
def cholesky_rank1_update_jit(L, v):
    # Rank-1 update of a lower-triangular Cholesky factor:
    # returns L_new such that L_new @ L_new.T == L @ L.T + outer(v, v).
    L_new = L.detach().clone()
    n = L.shape[0]
    v = v.clone()  # Clone v to avoid modifying the original vector
    for k in range(n):
        Lkk = L_new[k, k]
        vk = v[k]
        # Rotation that eliminates v[k] against the diagonal entry
        r = torch.sqrt(Lkk**2 + vk**2)
        c = r / Lkk
        s = vk / Lkk
        L_new[k, k] = r
        if k + 1 < n:
            # L_ik and v_i are views, so the column update below is
            # already visible when the new v[k+1:n] is computed.
            L_ik = L_new[k+1:n, k]
            v_i = v[k+1:n]
            L_new[k+1:n, k] = (L_ik + s * v_i) / c
            v[k+1:n] = c * v_i - s * L_ik
    return L_new
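For context, here is roughly how I compare the two approaches. The matrix size, device (CPU), and tolerance below are only illustrative, not my exact benchmark setup:

import time
import torch

# Illustrative problem size; my real use case may differ.
n = 512
A = torch.randn(n, n)
A = A @ A.T + n * torch.eye(n)  # make A symmetric positive definite
L = torch.linalg.cholesky(A)
v = torch.randn(n)

cholesky_rank1_update_jit(L, v)  # warm-up call so JIT compilation is not timed

t0 = time.perf_counter()
L_up = cholesky_rank1_update_jit(L, v)          # rank-1 update of the factor
t1 = time.perf_counter()
L_full = torch.linalg.cholesky(A + torch.outer(v, v))  # full re-factorization
t2 = time.perf_counter()

print("rank-1 update :", t1 - t0, "s")
print("full cholesky :", t2 - t1, "s")
print("factors match :", torch.allclose(L_up, L_full, atol=1e-4))

In my runs the loop-based update comes out slower than the full factorization, which is what prompted this question.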
Does anyone have ideas on how to improve the performance of this code?