Storing intermediate data that are not tensors

I have a custom autograd.Function where, in the forward, I solve a sparse linear system with the input vector as the right-hand side, using a prefactorization from scipy.sparse.
In the backward, computing the gradient requires solving the same sparse linear system with the incoming gradient as the right-hand side, so I need access to the prefactorized object in the backward.
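For reference, these are the adjoint identities behind this backward pass. They are standard results, written with $w$ for the weights that parameterize $M$. How `compute_sparse_matrix_from_weight` builds $M$ is not shown, so $\partial M / \partial w_k$ stays abstract. The backward solve uses $M^{\top}$, so it is literally the same system only when $M$ is symmetric. SuperLU can still reuse the same LU factors for the transposed solve through `solve(rhs, trans="T")`.

$$
x = M^{-1} b, \qquad g = \frac{\partial L}{\partial x}, \qquad
\lambda = M^{-\top} g, \qquad
\frac{\partial L}{\partial b} = \lambda, \qquad
\frac{\partial L}{\partial M} = -\lambda\, x^{\top}, \qquad
\frac{\partial L}{\partial w_k} = -\lambda^{\top} \frac{\partial M}{\partial w_k}\, x
$$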

  • What is the best way to achieve this?

Below is an MWE of what I’ve been doing so far (storing the objects in ctx.intermediate and deleting them in the backward), but it has a downside:

  • I can’t use retain_graph=True when calling backward(). ctx.intermediate gets deleted during the first backward pass, so the second pass through the backward function fails because the attribute no longer exists.
    If only I had access to the retain_graph flag in that scope, I could delete ctx.intermediate only when retain_graph=False, but AFAIK I don’t.
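The failure can be reproduced in isolation with a toy Function. `DeletesOnBackward` is a hypothetical name, and the float stored on ctx stands in for the SuperLU object. The first backward deletes the attribute, so the second backward over the retained graph raises an AttributeError inside `backward`, and that error propagates out of `loss.backward()`.

```python
import torch


class DeletesOnBackward(torch.autograd.Function):
    """Toy version of the pattern above: a non-tensor on ctx, deleted in backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.intermediate = 3.0        # stand-in for any non-tensor object
        return x * ctx.intermediate

    @staticmethod
    def backward(ctx, grad_output):
        scale = ctx.intermediate
        del ctx.intermediate          # freed after the first backward pass...
        return grad_output * scale


x = torch.ones(3, requires_grad=True)
loss = DeletesOnBackward.apply(x).sum()
loss.backward(retain_graph=True)      # first pass succeeds

try:
    loss.backward()                   # ...so the second pass cannot find it
    second_pass_failed = False
except AttributeError:
    second_pass_failed = True

print(second_pass_failed)
```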

import scipy.sparse.linalg
import torch


class MyCustomFunction(torch.autograd.Function):

    @staticmethod
    def precomputation(weights):

        # Build the sparse matrix from the weights: M is a sparse scipy matrix in CSC format.
        M = compute_sparse_matrix_from_weight(weights)
        # Decompose using LU: Mpre is an object of type scipy.sparse.linalg.SuperLU
        Mpre = scipy.sparse.linalg.splu(M)
        return Mpre

    @staticmethod
    def forward(ctx, weights, b, Mpre):

        # Solve the system with b as the right-hand side (SuperLU.solve works on NumPy arrays)
        res = Mpre.solve(b.detach().numpy())

        # Save data for backward
        ctx.intermediate = (res, weights, Mpre)
        # ctx.save_for_backward(res, weights, Mpre)   # Doesn't work for non-tensor arguments

        return torch.from_numpy(res)

    @staticmethod
    def backward(ctx, grad_output):

        res, weights, Mpre = ctx.intermediate

        # Set initial grads
        grad_weights = grad_rhs = None

        # M^{-T} g is needed for computing the backward (identical to M^{-1} g when M is symmetric);
        # trans="T" reuses the same LU factors for the transposed solve
        grad_res = torch.from_numpy(Mpre.solve(grad_output.detach().numpy(), trans="T"))

        if ctx.needs_input_grad[0]:
            grad_weights = compute_grad_wrt_weights(res, grad_res, grad_output, weights)
        if ctx.needs_input_grad[1]:
            grad_rhs = grad_res     # Derivative wrt the RHS is directly M^{-T} g

        # This deletes what we stored in ctx. If we don't, these objects are never freed and memory keeps growing.
        # We could have used ctx.save_for_backward(), but it only accepts torch.Tensor arguments...
        # The only problem is that .backward(retain_graph=True) then doesn't work: ctx.intermediate
        # gets deleted after the first backward, and retain_graph is not accessible in this function,
        # so I cannot delete ctx.intermediate only when retain_graph=False
        del ctx.intermediate

        # One gradient per forward() input: weights, b, Mpre
        return grad_weights, grad_rhs, None


# Example of how I use that function in an optimization process
# (init_weights, f, N and lr are defined elsewhere):
weights = init_weights()

# Start optimization
for it in range(100):
    weights.requires_grad_(True)

    # precomputation() returns only the SuperLU object
    Mpre = MyCustomFunction.precomputation(weights)

    def xi(x):
        return MyCustomFunction.apply(weights, x, Mpre)

    # Many many calls to this function
    uk = torch.zeros(N)
    for i in range(1000):
        uk = xi(uk)

    loss = f(uk)
    loss.backward()
    grad_weights = weights.grad.clone()

    # Update the weights
    with torch.no_grad():
        weights -= lr * grad_weights
        weights.grad.zero_()


Hi,

You should use save_for_backward() for any input or output, and plain attributes on ctx (ctx.something = ...) for everything else.
So in your case:

# In forward
ctx.res = res
ctx.save_for_backward(weights, Mpre)

# In backward
res = ctx.res
weights, Mpre = ctx.saved_tensors

If you do that, you won’t need del ctx.intermediate. It is saving some inputs directly on ctx that creates the memory leak.
And you will then be able to use retain_graph=True.
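A toy sketch of that split, where `ScaledSquare` and the float `scale` are hypothetical stand-ins: the input tensor goes through `save_for_backward`, the non-tensor is a plain ctx attribute that is never deleted, and two backward passes over a retained graph both succeed, accumulating into `x.grad`.

```python
import torch


class ScaledSquare(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.save_for_backward(x)   # input tensor: saved through save_for_backward
        ctx.scale = scale          # non-tensor: plain attribute on ctx, never deleted
        return scale * x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # d(scale * x^2)/dx = 2 * scale * x; `scale` gets no gradient
        return 2 * ctx.scale * x * grad_output, None


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
loss = ScaledSquare.apply(x, 0.5).sum()
loss.backward(retain_graph=True)
first = x.grad.clone()
loss.backward()                    # still works: saved tensors retained, ctx.scale intact
```

With `scale = 0.5` the gradient equals `x`, so after the second pass `x.grad` is twice the first result.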

I cannot do save_for_backward(weights, Mpre) because Mpre is not a tensor, and I get the error:

TypeError: save_for_backward can only save variables, but argument 1 is of type SuperLU

That is why I used ctx.intermediate.
When you say “for any input or output”, do you mean the inputs and outputs of the forward function?

Oh OK, in that case you can store Mpre on ctx like res. But keep weights in save_for_backward to avoid the memory leak.

> for any input or output

Yes, I mean the inputs and outputs of the forward() function.
And as you mentioned, this applies only to the tensors among them. Other types can be saved in ctx.
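Putting the thread's advice together for the sparse solve gives the sketch below. It is built on assumptions that are not in the thread. The parameterization M(w) = A + diag(w) is hypothetical, and so are the fixed tridiagonal A, `factorize`, and `SparseSolve`. They are chosen so the weight gradient has the closed form dL/dw_i = -λ_i x_i. The solution tensor goes through `save_for_backward`, the SuperLU object is a plain ctx attribute, nothing is deleted in backward, and the adjoint uses `solve(..., trans="T")`.

```python
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla
import torch

# Hypothetical setup: M(w) = A + diag(w), where A is a fixed, deliberately
# non-symmetric tridiagonal matrix, so the transposed solve actually matters.
N = 6
A = sp.diags(
    [np.full(N - 1, -1.0), np.full(N, 4.0), np.full(N - 1, -2.0)],
    [-1, 0, 1],
    format="csc",
)


def factorize(weights):
    """Build M(w) in CSC format and return its SuperLU factorization."""
    M = (A + sp.diags([weights.detach().numpy()], [0])).tocsc()
    return spla.splu(M)


class SparseSolve(torch.autograd.Function):
    @staticmethod
    def forward(ctx, weights, b, Mpre):
        x = torch.from_numpy(Mpre.solve(b.detach().numpy()))
        ctx.save_for_backward(x)  # tensors (here the output) go through save_for_backward
        ctx.Mpre = Mpre           # the SuperLU object is stored as a plain ctx attribute
        return x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Adjoint solve lambda = M^{-T} g, reusing the LU factors via trans="T"
        lam = torch.from_numpy(ctx.Mpre.solve(grad_output.detach().numpy(), trans="T"))
        grad_weights = grad_b = None
        if ctx.needs_input_grad[0]:
            grad_weights = -lam * x  # dL/dw_i = -lambda_i * x_i for M = A + diag(w)
        if ctx.needs_input_grad[1]:
            grad_b = lam
        return grad_weights, grad_b, None  # one entry per forward() input; Mpre gets None


torch.manual_seed(0)
w = (torch.rand(N, dtype=torch.double) + 1.0).requires_grad_()
b = torch.randn(N, dtype=torch.double, requires_grad=True)

# Refactorize inside the wrapper so finite differences on w see the perturbed matrix
ok = torch.autograd.gradcheck(lambda w_, b_: SparseSolve.apply(w_, b_, factorize(w_)), (w, b))

# Nothing is deleted in backward, so a second pass over a retained graph works
loss = SparseSolve.apply(w, b, factorize(w)).square().sum()
loss.backward(retain_graph=True)
first = w.grad.clone()
loss.backward()
print(ok, torch.allclose(w.grad, 2 * first))
```

In the original optimization loop, `factorize` plays the role of `precomputation`. The factorization is done once per outer iteration and shared by all inner calls. The gradcheck wrapper refactorizes only because finite differences on `w` must see the perturbed matrix.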