Cholesky rank-1 update in PyTorch

Hi there,

I am trying to write a Cholesky rank-1 update function, together with its backward pass, in PyTorch, since PyTorch does not provide one natively. However, my function turns out to be slower than simply recomputing the factorization with torch.linalg.cholesky. The function I wrote is as follows:
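For reference, the update computes the factor L' of A + v v^T directly from L, where A = L L^T, in O(n^2) operations instead of the O(n^3) of a fresh factorization. Step k of the loop in the function is a Givens rotation of column k of L against the working vector v:

```latex
\begin{aligned}
&A = L L^\top, \qquad L' L'^\top = A + v v^\top, \\
&r_k = \sqrt{L_{kk}^2 + v_k^2}, \quad c_k = \frac{r_k}{L_{kk}}, \quad s_k = \frac{v_k}{L_{kk}}, \quad L'_{kk} = r_k, \\
&L'_{ik} = \frac{L_{ik} + s_k v_i}{c_k}, \qquad
 v_i \leftarrow c_k v_i - s_k L'_{ik} = \frac{v_i - s_k L_{ik}}{c_k}, \qquad i > k.
\end{aligned}
```

Here L_{ik} and v_i are the values at the start of step k. Because c_k^2 - s_k^2 = (r_k^2 - v_k^2) / L_{kk}^2 = 1, the two forms of the v update agree, which is why the loop may use the already-updated column.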

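```python
import torch


@torch.jit.script
def cholesky_rank1_update_jit(L, v):
    # Returns the Cholesky factor of L @ L.T + outer(v, v).
    # detach() cuts the graph to L, so gradients w.r.t. L have to come
    # from a separately written backward pass.
    L_new = L.detach().clone()
    n = L.shape[0]
    v = v.clone()  # Clone v to avoid modifying the original vector
    for k in range(n):
        Lkk = L_new[k, k]
        vk = v[k]
        # Givens rotation that folds v[k] into the diagonal entry
        r = torch.sqrt(Lkk**2 + vk**2)
        c = r / Lkk
        s = vk / Lkk
        L_new[k, k] = r
        if k + 1 < n:
            L_ik = L_new[k+1:n, k]  # a view into L_new, not a copy
            v_i = v[k+1:n]
            L_new[k+1:n, k] = (L_ik + s * v_i) / c
            # L_ik aliases L_new, so it now holds the updated column,
            # which is what this form of the v update requires.
            v[k+1:n] = c * v_i - s * L_ik
    return L_new
```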
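Each of the n iterations above launches several tiny tensor operations, so the runtime is likely dominated by per-op overhead (kernel launches on GPU) rather than arithmetic, whereas torch.linalg.cholesky is a single library call. If many factors need updating at once (for example, one per sample), that overhead can be amortized by running the same loop over a batch dimension. A minimal sketch, assuming inputs of shape (..., n, n) and (..., n); cholesky_rank1_update_batched is a hypothetical name. It still writes in place, so like the original it is meant for a forward pass whose gradient is supplied separately (for example inside a torch.autograd.Function); autograd may reject the in-place writes if asked to trace it directly.

```python
import torch


def cholesky_rank1_update_batched(L: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Givens-rotation rank-1 update applied to every matrix in a batch.

    L: (..., n, n) lower-triangular Cholesky factors.
    v: (..., n) update vectors.
    Returns the factors of L @ L^T + v v^T, same shape as L.
    """
    L = L.clone()
    v = v.clone()
    n = L.shape[-1]
    for k in range(n):
        Lkk = L[..., k, k]
        vk = v[..., k]
        r = torch.sqrt(Lkk * Lkk + vk * vk)
        # Trailing dim of size 1 so c and s broadcast over the column entries.
        c = (r / Lkk).unsqueeze(-1)
        s = (vk / Lkk).unsqueeze(-1)
        L[..., k, k] = r
        if k + 1 < n:
            # col is a fresh tensor, so the v update does not rely on view aliasing.
            col = (L[..., k + 1:, k] + s * v[..., k + 1:]) / c
            L[..., k + 1:, k] = col
            v[..., k + 1:] = c * v[..., k + 1:] - s * col
    return L
```

The loop is still sequential in k, so this only helps when the batch is large enough to keep each step busy.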

Does anyone have ideas on how to improve the performance of this code?
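One possible direction, sketched here for a single positive update A + v v^T: write L L^T + v v^T = L (I + p p^T) L^T with p = L^{-1} v. The Cholesky factor of I + p p^T has a closed form: with t_0 = 1 and t_j = 1 + p_1^2 + ... + p_j^2, its diagonal is sqrt(t_j / t_{j-1}) and its entry (i, j) below the diagonal is p_i p_j / sqrt(t_{j-1} t_j). Multiplying by L then reduces to a reverse cumulative sum along the rows of L * p, so the whole update becomes one triangular solve, two cumulative sums and elementwise products, with no Python loop and no in-place writes. Autograd can therefore differentiate it without a hand-written backward. The name cholesky_rank1_update_vec is hypothetical, the sketch has not been benchmarked here, and it may be less robust numerically than the rotation-based loop, so it is worth checking against torch.linalg.cholesky on your own data.

```python
import torch


def cholesky_rank1_update_vec(L: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Cholesky factor of L @ L.T + outer(v, v) for a single (n, n) factor L.

    Uses L L^T + v v^T = L (I + p p^T) L^T with p = L^{-1} v and the
    closed-form factor of I + p p^T (diagonal d_j, entries p_i * beta_j
    below the diagonal). All ops are out-of-place, so autograd works.
    """
    # p = L^{-1} v via one triangular solve, O(n^2).
    p = torch.linalg.solve_triangular(L, v.unsqueeze(-1), upper=False).squeeze(-1)

    # t_j = 1 + sum_{k<=j} p_k^2 and its predecessor t_{j-1}, with t_0 = 1.
    t = 1.0 + torch.cumsum(p * p, dim=-1)
    t_prev = torch.cat([torch.ones_like(t[:1]), t[:-1]])

    d = torch.sqrt(t / t_prev)          # diagonal of the factor of I + p p^T
    beta = p / torch.sqrt(t * t_prev)   # column scales of its strictly lower part

    # (L @ strict_lower(p beta^T))_{ij} = beta_j * sum_{k>j} L_{ik} p_k:
    # a reverse cumulative sum along each row of W = L * p, minus W itself.
    W = L * p
    C = torch.flip(torch.cumsum(torch.flip(W, dims=[-1]), dim=-1), dims=[-1]) - W

    # Column j of the result is d_j * L[:, j] + beta_j * C[:, j].
    return L * d + C * beta
```

As written it handles one n×n factor; supporting a batch dimension would need the slicing in t_prev and the broadcasting of p adjusted.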