Thanks for the code.
It seems you want to ignore this single sample in the calculation.
I’m not sure, but I would assume that removing this sample from the batch and calling cholesky again on the cleaned batch of matrices might work.
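A minimal sketch of that idea, assuming a PyTorch version that provides torch.linalg.cholesky_ex (1.9 or later, as far as I know). It returns an info tensor alongside the factors, where 0 marks a successful factorization, so the failed samples can be masked out without a second call. The helper name cholesky_skip_failed and the toy batch are made up for illustration:

```python
import torch

def cholesky_skip_failed(A):
    # cholesky_ex does not raise on failure by default (check_errors=False);
    # info[i] == 0 means the factorization of A[i] succeeded.
    L, info = torch.linalg.cholesky_ex(A)
    ok = info == 0
    # Keep only the factors of the samples that succeeded.
    return L[ok], ok

# Toy batch: the middle matrix is negative definite, so it should fail.
A = torch.stack([4.0 * torch.eye(2), -torch.eye(2), 9.0 * torch.eye(2)])
L_ok, ok = cholesky_skip_failed(A)
```

The mask ok can then be used to drop the same samples from whatever else is computed per batch entry.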
Thank you!
Unfortunately, I would lose my performance gain if I had to analyze which samples fail and remove those specifically. There might be more than one.
Unfortunately, the solution was to implement my own simple batched Cholesky (matching th.cholesky(..., upper=False)) and then deal with NaN values using th.isnan.
import torch as th

# nograd cholesky
def cholesky(A):
    L = th.zeros_like(A)
    for i in range(A.shape[-1]):
        for j in range(i+1):
            s = 0.0
            for k in range(j):
                s = s + L[...,i,k] * L[...,j,k]
            L[...,i,j] = th.sqrt(A[...,i,i] - s) if (i == j) else \
                         (1.0 / L[...,j,j] * (A[...,i,j] - s))
    return L
Based on this code, I also implemented it in C++ with LibTorch, and it worked flawlessly.
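For anyone wondering what the th.isnan step could look like, here is a rough sketch. The batched cholesky is repeated so the snippet runs on its own; the toy batch and the choice to zero out failed factors are assumptions for illustration, not necessarily what was done in the original code:

```python
import torch as th

def cholesky(A):
    # Same simple batched Cholesky as above (lower triangular, no autograd).
    L = th.zeros_like(A)
    for i in range(A.shape[-1]):
        for j in range(i + 1):
            s = 0.0
            for k in range(j):
                s = s + L[..., i, k] * L[..., j, k]
            L[..., i, j] = th.sqrt(A[..., i, i] - s) if (i == j) else \
                           (1.0 / L[..., j, j] * (A[..., i, j] - s))
    return L

# Toy batch: the middle matrix is negative definite, so sqrt() yields NaN.
A = th.stack([4.0 * th.eye(2), -th.eye(2), 9.0 * th.eye(2)])
L = cholesky(A)

# One flag per sample: True if any entry of its factor is NaN.
failed = th.isnan(L).any(dim=-1).any(dim=-1)

# Assumption: replace failed factors with zeros instead of removing them,
# so the batch shape stays fixed.
L_clean = th.where(failed[..., None, None], th.zeros_like(L), L)
```

Since nothing raises, the whole batch goes through in one pass and any number of failing samples is handled by the mask.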