Using sksparse.cholmod.cholesky for solving sparse systems

Hi!

I use torch.optim to optimize a function which actively uses sparse matrix operations. The current torch.sparse package does not support operations like computing Cholesky factorizations of sparse tensors, computing log-determinants or solving sparse linear systems. Therefore it would be very beneficial if I could find some way of implementing these operations as custom autograd functions.

From previous experience, I have found that sksparse.cholmod.cholesky works very well for sparse Cholesky decomposition of scipy.sparse matrices. This also supports computations of the log-determinant and can be used to solve systems. I therefore want to use this within my autograd function.

Following this tutorial (Flaport.net | Creating a Pytorch solver for sparse linear systems) I have the following code that solves a sparse system using the scipy.sparse.linalg.spsolve solver.

import torch
import numpy as np
import scipy.sparse as sparse

class Solve(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, b):
        A_np = A.data.numpy()
        b_np = b.data.numpy()
        A_sp = sparse.csc_matrix(A_np)
        
        # Solve using spsplve
        x_np = sparse.linalg.spsolve(A_sp, b_np)

        x = torch.tensor(x_np, requires_grad=True)
        ctx.save_for_backward(A, b, x)
        return x

    @staticmethod
    def backward(ctx, grad):
        A, b, x = ctx.saved_tensors
        gradb = Solve.apply(A.T, grad)
        gradA = -gradb[:, None] * x[None, :]
        return gradA, gradb

custom_solve_system = Solve.apply

# Test the solver
A = torch.randn(3, 3, requires_grad=True, dtype=torch.float64)
A = (A + A.t()) / 2 + torch.eye(3) # Make SPD
b = torch.randn(3, requires_grad=True, dtype=torch.float64)

# Returns True
torch.autograd.gradcheck(custom_solve_system, [A.to_dense(), b])

According to the gradcheck function, the custom_solve_system function works as intended. However, now I replace the spsolve with a solver using the sksparse.cholmod.cholesky. This apparently causes problems.

import torch
import numpy as np
import scipy.sparse as sparse
from sksparse.cholmod import cholesky

class Solve2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, b):
        A_np = A.data.numpy()
        b_np = b.data.numpy()
        A_sp = sparse.csc_matrix(A_np)

        # Instead use the sksparse Cholesky factorization
        A_chol = cholesky(A_sp)
        x_np = A_chol.solve_A(b_np)

        x = torch.tensor(x_np, requires_grad=True)
        ctx.save_for_backward(A, b, x)
        return x

    @staticmethod
    def backward(ctx, grad):
        A, b, x = ctx.saved_tensors
        gradb = Solve2.apply(A.T, grad)
        gradA = -gradb[:, None] * x[None, :]
        return gradA, gradb

custom_solve_system_2 = Solve2.apply

# Test the solver
A = torch.randn(3, 3, requires_grad=True, dtype=torch.float64)
A = (A + A.t()) / 2 + torch.eye(3) # Make SPD
b = torch.randn(3, requires_grad=True, dtype=torch.float64)

# Raises an error.
torch.autograd.gradcheck(custom_solve_system_2, [A.to_dense(), b])

Since this gives an error, the function no longer works according to the gradcheck. Does anybody know how I can fix this? Any help would be much appreciated!