# Storing intermediate data that are not tensors

I have a custom autograd.Function where, in the forward, I solve a sparse linear system on the input vector using a prefactorization from scipy.sparse.
In the backward, computing the gradient requires solving the same sparse linear system with the incoming gradient as the right-hand side, so I need access to the prefactorization object in the backward.
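
(For context, this is just the adjoint of the solve: if x = M^{-1} b, then dL/db = M^{-T} dL/dx, i.e. the same system when M is symmetric, and the transposed one otherwise.)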

• What is the best way to achieve this?

Below is an MWE of what I’ve been doing until now (storing into ctx.intermediate and deleting it in the backward), but it has a downside:

• I can’t use `retain_graph=True` when calling backward(), because ctx.intermediate gets deleted the first time, so the second time it goes through the backward function the result is no longer correct.
If only I had access to the variable `retain_graph` in that scope, I could delete ctx.intermediate only when retain_graph=False, but AFAIK I don’t.
``````
import scipy.sparse.linalg
import torch


class MyCustomFunction(torch.autograd.Function):

    @staticmethod
    def precomputation(weights):
        # Build the sparse matrix from the weights: M is a sparse scipy matrix in CSC format.
        M = compute_sparse_matrix_from_weight(weights)
        # Decompose using LU: Mpre is an object of type scipy.sparse.linalg.SuperLU
        Mpre = scipy.sparse.linalg.splu(M)
        return Mpre

    @staticmethod
    def forward(ctx, weights, b, Mpre):
        # Solve the system with b as the right hand side
        res = Mpre.solve(b.detach().numpy())

        # Save data for backward
        ctx.intermediate = (res, weights, Mpre)
        # ctx.save_for_backward(res, weights, Mpre)   # Doesn't work: it only accepts torch.Tensor arguments

        return torch.from_numpy(res)

    @staticmethod
    def backward(ctx, grad_output):
        res, weights, Mpre = ctx.intermediate

        # M^{-1}g is needed for computing the backward (gradient w.r.t. b)
        grad_b = torch.from_numpy(Mpre.solve(grad_output.detach().numpy()))

        # This deletes what we stored in ctx. If we don't do that, these objects are not freed,
        # and memory grows and grows.
        # We could have used ctx.save_for_backward(), but it only takes torch.Tensor arguments...
        # The only problem with this is that .backward(retain_graph=True) doesn't work, because
        # ctx.intermediate gets deleted after the first backward, and I have no access to the
        # variable retain_graph in this function to only delete ctx.intermediate when
        # retain_graph=False.
        del ctx.intermediate

        # Gradient w.r.t. weights is omitted in this MWE
        return None, grad_b, None


# Example of how I use that function in an optimization process:
weights = init_weights()

# Start optimization
for it in range(100):

    Mpre = MyCustomFunction.precomputation(weights)

    def xi(x):
        return MyCustomFunction.apply(weights, x, Mpre)

    # Many many calls to this function
    uk = torch.zeros(N)
    for i in range(1000):
        uk1 = xi(uk)
        uk = uk1

    loss = f(uk)
    loss.backward()

    # Update the weights ...
``````

Hi,

You should use `save_for_backward()` for any input or output and `ctx.` for everything else.

``````
# In forward
ctx.res = res
ctx.save_for_backward(weights, Mpre)

# In backward
res = ctx.res
weights, Mpre = ctx.saved_tensors
``````

If you do that, you won’t need to do `del ctx.intermediate`. It is the fact that you save some inputs in the ctx that creates the memory leak.
And so you will be able to use `retain_graph=True`.

I cannot do `save_for_backward(weights, Mpre)` because Mpre is not a tensor, and I get the error:

TypeError: save_for_backward can only save variables, but argument 1 is of type SuperLU

That is why I used ctx.intermediate.
When you say “for any input or output”, do you mean the inputs/outputs of the forward function?

Ah ok, in that case you can save `Mpre` the same way as `res`. But keep `weights` in `save_for_backward()` to avoid the memory leak.

> for any input or output

Yes, I mean the inputs and outputs of the `forward()` function.
And as you mentioned, that applies to input and output Tensors only. Other types can be saved in `ctx.`
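
Putting the two replies together, here is a minimal sketch of what the corrected forward/backward could look like for this case (the gradient w.r.t. the weights is still omitted, and the backward solve assumes the same system as in the original MWE, i.e. a symmetric M):

``````
# Non-tensor objects (the SuperLU factorization, the numpy result) go directly on ctx;
# the tensor input goes through save_for_backward(). No manual `del` is needed, so
# calling backward(retain_graph=True) and then backward() again works.

@staticmethod
def forward(ctx, weights, b, Mpre):
    res = Mpre.solve(b.detach().numpy())
    ctx.res = res                    # numpy array: keep on ctx
    ctx.Mpre = Mpre                  # SuperLU object: keep on ctx
    ctx.save_for_backward(weights)   # tensor input: save_for_backward avoids the leak
    return torch.from_numpy(res)

@staticmethod
def backward(ctx, grad_output):
    res, Mpre = ctx.res, ctx.Mpre
    weights, = ctx.saved_tensors
    # Same solve with the incoming gradient as the right hand side
    grad_b = torch.from_numpy(Mpre.solve(grad_output.detach().numpy()))
    return None, grad_b, None        # gradient w.r.t. weights omitted, as in the MWE
``````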