# When should you use save_for_backward vs. storing in ctx?

I’m defining a new function using the 0.2 style and am wondering when it is appropriate to store intermediate results in the ctx object as opposed to using the save_for_backward function.

Here is a simple example:

```python
from torch.autograd import Function

# OPTION 1: save the inputs with save_for_backward and recompute c in backward
class Square(Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        c = a + b
        return c * c

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_variables  # 0.2-style accessor
        c = a + b

        grad = grad_output * (2 * c)
        return grad, grad


# OPTION 2: store the intermediate result directly on the ctx object
class Square(Function):
    @staticmethod
    def forward(ctx, a, b):
        c = a + b
        ctx.intermediate_results = c
        return c * c

    @staticmethod
    def backward(ctx, grad_output):
        c = ctx.intermediate_results

        grad = grad_output * (2 * c)
        return grad, grad
```

What are the advantages/disadvantages to Option 1 over Option 2? When would you want to use one over the other?

(In a real example, the computation of the intermediate variable “c” would be much more intensive, e.g. involve O(n^2) operations.)

Thank you.


I think you can also store results directly on the ctx.
My impression is that the rule is “use `save_for_backward` if there isn’t a reason not to”. Efficiency is one such reason.

Best regards

Thomas

Ok, but is there a reason why you should use save_for_backward? How much more efficient does it have to be to justify storing directly to ctx?

I don’t think there are hard rules about it, so I would look at torch.autograd._functions for inspiration on when to use which, and see what looks similar to your case.
For example, in the Clamp function the mask is stored on the `ctx`, while for the square root a few lines down, just the input is saved.
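
For illustration only, here is a rough sketch of the two patterns (this is not the actual torch.autograd._functions source, and the class names are just placeholders): a clamp-style function keeps a cheap boolean mask on `ctx`, while a sqrt-style function saves a tensor via `save_for_backward`.

```python
import torch
from torch.autograd import Function

class Clamp(Function):
    @staticmethod
    def forward(ctx, x, lo, hi):
        # backward only needs a boolean mask (an intermediate result),
        # so it is stashed directly on ctx
        ctx.mask = (x >= lo) & (x <= hi)
        return x.clamp(lo, hi)

    @staticmethod
    def backward(ctx, grad_output):
        # gradients flow only where the input was not clamped;
        # lo and hi are plain numbers, so their "gradients" are None
        return grad_output * ctx.mask.to(grad_output.dtype), None, None

class Sqrt(Function):
    @staticmethod
    def forward(ctx, x):
        # the input tensor itself is needed for backward, so let autograd track it
        ctx.save_for_backward(x)
        return x.sqrt()

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors  # ctx.saved_variables in the 0.2/0.3 API
        return grad_output / (2 * x.sqrt())
```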

Best regards

Thomas

Thanks, Thomas. Looking through the source code, it seems like the main advantage of save_for_backward is that the saving is done in C rather than Python. So it seems like any time the computation of the intermediate result would allocate Python storage anyway, it makes sense to use the ctx object.
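
For what it’s worth, another difference beyond where the bookkeeping happens: tensors saved with `save_for_backward` are tracked by autograd’s version counters, so modifying one in place between forward and backward raises an error instead of silently producing stale gradients, which a tensor stashed directly on `ctx` would do. A minimal sketch, written against the current API where the accessor is `ctx.saved_tensors`:

```python
import torch
from torch.autograd import Function

class Square(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)   # autograd records x's version counter here
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors   # raises if x was modified in place since forward
        return grad_output * 2 * x

x = torch.randn(3, requires_grad=True) + 0   # non-leaf, so in-place ops are allowed
y = Square.apply(x)
x.add_(1)                                    # modify the saved tensor in place

try:
    y.sum().backward()
except RuntimeError as e:
    print(e)  # "... has been modified by an inplace operation"
```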


I think that Option 1 deletes the GPU memory of the Variable.

There will be a memory leak in Option 2, where ctx.intermediate_results won’t be released.

Huh? The ctx would go out of scope after the backward too, would it not?

Best regards

Thomas

We experienced a memory leak in a custom module using custom-made GPU kernels. Storing intermediate results in `ctx.x`, `ctx.z`, etc. resulted in them not being released over multiple mini-batches, quickly exhausting GPU memory. We did not experience this in PyTorch 0.3.1; it sounds like it is due to some changes in PyTorch 0.4. Either explicitly calling `del ctx.x` etc. or adding x, z to `.save_for_backward` fixed the memory leak. I would still like to get a good explanation of why you would want to use `save_for_backward` vs. directly assigning to `ctx`.
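
A minimal sketch of the first workaround (explicitly deleting the attribute), applied to the Option 2 Square from the original post rather than our custom-kernel module:

```python
import torch
from torch.autograd import Function

class Square(Function):
    @staticmethod
    def forward(ctx, a, b):
        c = a + b
        ctx.intermediate_results = c      # stored directly on ctx
        return c * c

    @staticmethod
    def backward(ctx, grad_output):
        c = ctx.intermediate_results
        del ctx.intermediate_results      # drop the reference so c can be freed promptly
        grad = grad_output * (2 * c)
        return grad, grad
```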

I seem to recall that in earlier PyTorch versions it was not possible to `save_for_backward` an intermediate result, only inputs to the `forward` function, but that seems to no longer be the case.
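
Indeed, in current PyTorch versions an intermediate tensor can be passed to `save_for_backward` (with the documented caveat that double backward may then not be supported), so the original Square example can keep `c` around without recomputing it and without stashing it on `ctx`. A sketch using the modern `ctx.saved_tensors` accessor:

```python
import torch
from torch.autograd import Function

class Square(Function):
    @staticmethod
    def forward(ctx, a, b):
        c = a + b
        ctx.save_for_backward(c)   # c is an intermediate, not a forward input or output
        return c * c

    @staticmethod
    def backward(ctx, grad_output):
        (c,) = ctx.saved_tensors
        grad = grad_output * (2 * c)
        return grad, grad

a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
Square.apply(a, b).sum().backward()   # a.grad and b.grad are both 2 * (a + b)
```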
