PyTorch equivalent of scipy.sparse.linalg.gmres

Hi all,

Is there an equivalent of scipy.sparse.linalg.gmres in PyTorch? JAX seems to have a direct equivalent, jax.scipy.sparse.linalg.gmres, so I was wondering whether there's a way of doing the same in PyTorch.

I’ve had a look and can’t find anything, but I may have missed it! If there’s a way of doing this in PyTorch, please do let me know!
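For reference, here is a minimal sketch of the kind of routine being asked about, built only from standard torch operations. It is a plain, unrestarted GMRES: an Arnoldi process with modified Gram-Schmidt builds an orthonormal Krylov basis, and a small least-squares problem on the upper Hessenberg matrix is solved with torch.linalg.lstsq at each step. The function name `gmres` and its parameters `x0`, `tol` and `max_iter` are hypothetical, chosen only to loosely mirror the SciPy call; this is not an existing PyTorch API. It assumes `A` supports `A @ v` for a 1-D tensor `v` (tested here only with dense tensors).

```python
import torch


def gmres(A, b, x0=None, tol=1e-8, max_iter=None):
    """Hypothetical unrestarted GMRES sketch for A x = b (not a PyTorch API).

    A: object supporting A @ v for a 1-D tensor v (e.g. a dense tensor).
    tol: relative tolerance on the residual norm ||b - A x|| / ||b||.
    """
    n = b.shape[0]
    if max_iter is None:
        max_iter = n
    x0 = torch.zeros_like(b) if x0 is None else x0
    r0 = b - A @ x0
    beta = torch.linalg.norm(r0)
    bnorm = torch.linalg.norm(b)
    if beta <= tol * bnorm:
        return x0

    # Q: orthonormal Krylov basis; H: (k+1) x k upper Hessenberg matrix.
    Q = torch.zeros(n, max_iter + 1, dtype=b.dtype, device=b.device)
    H = torch.zeros(max_iter + 1, max_iter, dtype=b.dtype, device=b.device)
    Q[:, 0] = r0 / beta
    eps = torch.finfo(b.dtype).eps

    for k in range(max_iter):
        # Arnoldi step: expand the basis, orthogonalise with modified Gram-Schmidt.
        w = A @ Q[:, k]
        for j in range(k + 1):
            H[j, k] = torch.dot(Q[:, j], w)
            w = w - H[j, k] * Q[:, j]
        H[k + 1, k] = torch.linalg.norm(w)

        # Solve min_y ||beta * e1 - H_k y|| on the small (k+2) x (k+1) system.
        Hk = H[: k + 2, : k + 1]
        e1 = torch.zeros(k + 2, dtype=b.dtype, device=b.device)
        e1[0] = beta
        y = torch.linalg.lstsq(Hk, e1.unsqueeze(1)).solution.squeeze(1)
        res = torch.linalg.norm(e1 - Hk @ y)

        # Stop on convergence, or on breakdown (the Krylov space is invariant).
        if res <= tol * bnorm or H[k + 1, k] <= eps * beta:
            break
        Q[:, k + 1] = w / H[k + 1, k]

    # Map the small solution back: x = x0 + Q_k y.
    return x0 + Q[:, : y.shape[0]] @ y
```

Usage would look like `x = gmres(A, b, tol=1e-10)`. Restarting (GMRES(m)) and preconditioning, both of which the SciPy version supports, are left out to keep the sketch short.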