When should you use save_for_backward vs. storing on ctx?

I’m defining a new autograd Function using the 0.2 style (static methods) and am wondering when it is appropriate to store intermediate results on the ctx object as opposed to passing them to save_for_backward.

Here is a simple example:

# OPTION 1: save the inputs, recompute the intermediate in backward
class Square(Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        c = a + b
        return c * c

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_variables  # saved_tensors in later versions
        c = a + b  # recomputed here, potentially expensive

        grad = grad_output * (2 * c)
        return grad, grad


# OPTION 2: stash the intermediate on ctx, nothing is recomputed
class Square(Function):
    @staticmethod
    def forward(ctx, a, b):
        c = a + b
        ctx.intermediate_results = c  # stored as a plain attribute
        return c * c

    @staticmethod
    def backward(ctx, grad_output):
        c = ctx.intermediate_results

        grad = grad_output * (2 * c)
        return grad, grad

What are the advantages/disadvantages of Option 1 over Option 2? When would you want to use one over the other?

(In a real example, the computation of the intermediate variable “c” would be a lot more intensive, e.g. involve O(n^2) operations)

Thank you.


I think you can store results either way.
My impression is that the rule is “use save_for_backward unless there is a reason not to”, and efficiency can be such a reason.

Best regards

Thomas

Ok, but is there a reason why you should use save_for_backward? How much more efficient does it have to be to justify storing directly on ctx?

I don’t think there are hard rules about it, so I would look at torch.autograd._functions for inspiration on when to use which, and see what looks similar to your case.
For example, the Clamp Function stores the mask on the ctx, while the square root a few lines down just saves the input.

Best regards

Thomas
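As a sketch of the two idioms (hypothetical Clamp- and Sqrt-style functions for illustration, not the actual torch.autograd._functions source; written against the current saved_tensors API rather than the 0.2-era saved_variables):

```python
import torch
from torch.autograd import Function

class ClampZero(Function):
    """Clamp-style: backward only needs a boolean mask, which is derived
    data rather than an input, so it is stored directly on ctx."""
    @staticmethod
    def forward(ctx, x):
        ctx.mask = x > 0            # plain attribute on ctx
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.mask.to(grad_output.dtype)

class Sqrt(Function):
    """Sqrt-style: backward needs the input tensor itself, so it goes
    through save_for_backward."""
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.sqrt()

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output / (2 * x.sqrt())
```

Usage is the same either way: `ClampZero.apply(x)` / `Sqrt.apply(x)`.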

Thanks, Thomas. Looking through the source code, it seems like the main advantage of save_for_backward is that the saving is done in C rather than in Python. So it seems that whenever computing the intermediate result would allocate Python storage anyway, it makes sense to use the ctx object.


I think that Option 1 releases the GPU memory of the Variable.

There will be a memory leak in Option 2, where ctx.intermediate_results won’t be released.

Huh? The ctx would go out of scope after the backward too, would it not?

Best regards

Thomas
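The point above can be illustrated in plain Python (no torch, with stand-in classes for the tensor and the ctx): an attribute on a ctx-like object only pins its referent for as long as the ctx itself is alive.

```python
import gc
import weakref

class FakeTensor:
    """Stand-in for a large GPU tensor (a plain object, so weakref works)."""

class Ctx:
    """Stand-in for the autograd ctx object."""

def forward():
    ctx = Ctx()
    ctx.intermediate_results = FakeTensor()  # Option 2 style storage
    return ctx

ctx = forward()
probe = weakref.ref(ctx.intermediate_results)
alive_before = probe() is not None  # True: ctx still references it

del ctx        # ctx dropped once backward is done and nothing else holds it
gc.collect()
alive_after = probe() is not None   # False: the intermediate went with it
```

So a leak would require something else, e.g. the autograd graph, to keep the ctx alive across iterations.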

We experienced a memory leak in a custom module using custom-made GPU kernels. Storing intermediate results as ctx.x, ctx.z, etc. resulted in them not being released across mini-batches, quickly exhausting GPU memory. We did not see this in PyTorch 0.3.1, so it sounds like it is due to some change in PyTorch 0.4. Either explicitly calling del ctx.x etc. or adding x and z to save_for_backward fixed the leak. I would still like a good explanation of when you should use save_for_backward vs. assigning directly to ctx.

I seem to recall that in earlier torch versions save_for_backward could only save inputs to the forward function, not intermediate results, but that no longer seems to be the case.
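Putting the pieces together, a version of the original Square that passes the intermediate c itself to save_for_backward (which newer versions allow) might look like this. This is a sketch against the modern saved_tensors API, not code verified in the thread:

```python
import torch
from torch.autograd import Function

class Square(Function):
    @staticmethod
    def forward(ctx, a, b):
        c = a + b                  # imagine this is the expensive step
        ctx.save_for_backward(c)   # saving an intermediate is allowed here
        return c * c

    @staticmethod
    def backward(ctx, grad_output):
        c, = ctx.saved_tensors
        grad = grad_output * (2 * c)
        return grad, grad          # one gradient per forward input
```

This avoids both the recomputation of Option 1 and the plain-attribute storage of Option 2.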
