Just to let everybody know, this is how I solved it.
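The solution, unfortunately, was to implement my own simple batched Cholesky decomposition (equivalent to `th.cholesky(..., upper=False)`) and then handle the NaN values with `th.isnan`. The NaNs show up when a matrix in the batch is not positive definite, because the diagonal step takes `th.sqrt` of a negative number.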
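```python
import torch as th


# Simple batched Cholesky ("nograd": meant for use without autograd).
# For symmetric positive-definite A of shape (..., n, n) it returns the
# lower-triangular L with A = L @ L.T, like th.cholesky(A, upper=False).
# Non-positive-definite inputs produce NaN entries in L.
def cholesky(A):
    L = th.zeros_like(A)
    for i in range(A.shape[-1]):
        for j in range(i + 1):
            # s = sum over k < j of L[i, k] * L[j, k]
            s = 0.0
            for k in range(j):
                s = s + L[..., i, k] * L[..., j, k]
            if i == j:
                # Diagonal entry
                L[..., i, j] = th.sqrt(A[..., i, i] - s)
            else:
                # Below-diagonal entry
                L[..., i, j] = (A[..., i, j] - s) / L[..., j, j]
    return L
```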
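A minimal sketch of the NaN handling step, assuming you want to flag the batch elements whose matrix is not positive definite and zero out their factors so the NaNs do not propagate. `cholesky_with_mask` and the zero-fill policy are hypothetical illustrations, not the code from this post. The factorization is the same algorithm as above, with the inner `k` loop written as a sum so the snippet runs on its own; for the positive-definite matrix it is checked against `th.linalg.cholesky`.

```python
import torch as th


def cholesky(A):
    # Same lower-triangular Cholesky as above; the k-loop is written as a sum.
    L = th.zeros_like(A)
    n = A.shape[-1]
    for i in range(n):
        for j in range(i + 1):
            s = (L[..., i, :j] * L[..., j, :j]).sum(-1)
            if i == j:
                L[..., i, j] = th.sqrt(A[..., i, i] - s)
            else:
                L[..., i, j] = (A[..., i, j] - s) / L[..., j, j]
    return L


def cholesky_with_mask(A):
    # Hypothetical helper: returns the factor plus a boolean mask that is
    # True for batch elements whose factor contains no NaN.
    L = cholesky(A)
    bad = th.isnan(L).flatten(-2).any(-1)
    # One possible policy: replace failed factors with zeros.
    L = th.where(bad[..., None, None], th.zeros_like(L), L)
    return L, ~bad


th.manual_seed(0)
M = th.randn(3, 3, dtype=th.float64)
spd = M @ M.T + 3 * th.eye(3, dtype=th.float64)  # positive definite
not_pd = th.tensor([[1.0, 2.0, 0.0],
                    [2.0, 1.0, 0.0],
                    [0.0, 0.0, 1.0]], dtype=th.float64)  # indefinite
A = th.stack([spd, not_pd])  # batch of two 3x3 matrices

L, ok = cholesky_with_mask(A)
print(ok.tolist())  # [True, False]
print(th.allclose(L[0], th.linalg.cholesky(spd)))  # True
```

Zero-filling is only one choice; depending on the application you might instead skip the failed elements, add a small diagonal jitter and retry, or raise an error.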
Using this code as a reference, I also implemented it in C++ with LibTorch, and it worked flawlessly.