# Storing intermediate data that are not tensors

I have a custom autograd.Function where, in the forward, I solve a sparse linear system on the input vector using a prefactorization from scipy.sparse.
In the backward, computing the gradient requires solving the same sparse linear system with the incoming gradient as the right-hand side, so I need access to the prefactorization object in the backward.
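
(For context, this is just the adjoint of the solve: if x = M^{-1} b, then dL/db = M^{-T} dL/dx, i.e. the same system when M is symmetric, and the transposed one otherwise.)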

• What is the best way to achieve this?

Below is an MWE of what I’ve been doing until now (storing into ctx.intermediate and deleting it in the backward), but it has a downside:

• I can’t use `retain_graph=True` when calling backward(), because ctx.intermediate gets deleted the first time, so the second time it goes through the backward function the result is no longer correct.
If only I had access to the variable `retain_graph` in that scope, I could delete ctx.intermediate only when retain_graph=False, but AFAIK I don’t.
``````
import scipy.sparse.linalg
import torch


class MyCustomFunction(torch.autograd.Function):

    @staticmethod
    def precomputation(weights):
        # Build the sparse matrix from the weights: M is a sparse scipy matrix in CSC format.
        M = compute_sparse_matrix_from_weight(weights)
        # Decompose using LU: Mpre is an object of type scipy.sparse.linalg.SuperLU
        Mpre = scipy.sparse.linalg.splu(M)
        return Mpre

    @staticmethod
    def forward(ctx, weights, b, Mpre):
        # Solve the system with b as the right hand side
        res = Mpre.solve(b.detach().numpy())

        # Save data for backward
        ctx.intermediate = (res, weights, Mpre)
        # ctx.save_for_backward(res, weights, Mpre)   # Doesn't work: it only accepts torch.Tensor arguments

        return torch.from_numpy(res)

    @staticmethod
    def backward(ctx, grad_output):
        res, weights, Mpre = ctx.intermediate

        # M^{-1}g is needed for computing the backward (gradient w.r.t. b)
        grad_b = torch.from_numpy(Mpre.solve(grad_output.detach().numpy()))

        # This deletes what we stored in ctx. If we don't do that, these objects are not freed,
        # and memory grows and grows.
        # We could have used ctx.save_for_backward(), but it only takes torch.Tensor arguments...
        # The only problem with this is that .backward(retain_graph=True) doesn't work, because
        # ctx.intermediate gets deleted after the first backward, and I have no access to the
        # variable retain_graph in this function to only delete ctx.intermediate when
        # retain_graph=False.
        del ctx.intermediate

        # Gradient w.r.t. weights is omitted in this MWE
        return None, grad_b, None


# Example of how I use that function in an optimization process:
weights = init_weights()

# Start optimization
for it in range(100):

    Mpre = MyCustomFunction.precomputation(weights)

    def xi(x):
        return MyCustomFunction.apply(weights, x, Mpre)

    # Many many calls to this function
    uk = torch.zeros(N)
    for i in range(1000):
        uk1 = xi(uk)
        uk = uk1

    loss = f(uk)
    loss.backward()

    # Update the weights ...
``````

Hi,

You should use `save_for_backward()` for any input or output and `ctx.` for everything else.

``````
# In forward
ctx.res = res
ctx.save_for_backward(weights, Mpre)

# In backward
res = ctx.res
weights, Mpre = ctx.saved_tensors
``````

If you do that, you won’t need to do `del ctx.intermediate`. It is the fact that you save some inputs in the ctx that creates the memory leak.
And so you will be able to use `retain_graph=True`.

I cannot do `save_for_backward(weights, Mpre)` because Mpre is not a tensor, and I get the error:

TypeError: save_for_backward can only save variables, but argument 1 is of type SuperLU

That is why I used ctx.intermediate.
When you say “for any input or output”, do you mean the inputs/outputs of the forward function?

Ah ok, in that case you can save `Mpre` the same way as `res`. But keep `weights` in `save_for_backward()` to avoid the memory leak.

> for any input or output

Yes, I mean the inputs and outputs of the `forward()` function.
And as you mentioned, that applies to input and output Tensors only. Other types can be saved in `ctx.`
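
Putting the two replies together, here is a minimal sketch of what the corrected forward/backward could look like for this case (the gradient w.r.t. the weights is still omitted, and the backward solve assumes the same system as in the original MWE, i.e. a symmetric M):

``````
# Non-tensor objects (the SuperLU factorization, the numpy result) go directly on ctx;
# the tensor input goes through save_for_backward(). No manual `del` is needed, so
# calling backward(retain_graph=True) and then backward() again works.

@staticmethod
def forward(ctx, weights, b, Mpre):
    res = Mpre.solve(b.detach().numpy())
    ctx.res = res                    # numpy array: keep on ctx
    ctx.Mpre = Mpre                  # SuperLU object: keep on ctx
    ctx.save_for_backward(weights)   # tensor input: save_for_backward avoids the leak
    return torch.from_numpy(res)

@staticmethod
def backward(ctx, grad_output):
    res, Mpre = ctx.res, ctx.Mpre
    weights, = ctx.saved_tensors
    # Same solve with the incoming gradient as the right hand side
    grad_b = torch.from_numpy(Mpre.solve(grad_output.detach().numpy()))
    return None, grad_b, None        # gradient w.r.t. weights omitted, as in the MWE
``````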