I’m defining a new function using the 0.2 style and am wondering when it is appropriate to store intermediate results in the ctx object as opposed to using the save_for_backward function.
Here is a simple example:
# OPTION 1
class Square(Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        c = a + b
        return c * c

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_variables
        c = a + b
        grad = grad_output * (2 * c)
        return grad, grad
# OPTION 2
class Square(Function):
    @staticmethod
    def forward(ctx, a, b):
        c = a + b
        ctx.intermediate_results = c
        return c * c

    @staticmethod
    def backward(ctx, grad_output):
        c = ctx.intermediate_results
        grad = grad_output * (2 * c)
        return grad, grad
What are the advantages/disadvantages of Option 1 over Option 2? When would you want to use one over the other?
(In a real example, the computation of the intermediate variable “c” would be much more expensive, e.g. involve O(n^2) operations.)
I don’t think there are hard rules about it, so I would look at torch.autograd._functions for inspiration on when to use which, and see what looks similar to your case.
For example, in the Clamp Function the mask is stored on the ctx, while for the square root a few lines down, just the input is saved.
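A minimal sketch of those two patterns, assuming a recent PyTorch API (ctx.saved_tensors in place of the older ctx.saved_variables; the class names here are made up for illustration):

```python
import torch
from torch.autograd import Function

class MyClamp(Function):
    # Derived data (a boolean mask) is stored directly on ctx.
    @staticmethod
    def forward(ctx, x, lo, hi):
        ctx.mask = (x >= lo) & (x <= hi)
        return x.clamp(lo, hi)

    @staticmethod
    def backward(ctx, grad_output):
        # lo and hi are plain floats, so their gradients are None
        return grad_output * ctx.mask.to(grad_output.dtype), None, None

class MySqrt(Function):
    # A tensor (here the output) is saved with save_for_backward.
    @staticmethod
    def forward(ctx, x):
        out = x.sqrt()
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (out,) = ctx.saved_tensors
        # d/dx sqrt(x) = 1 / (2 * sqrt(x))
        return grad_output / (2 * out)
```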
Thanks, Thomas. Looking through the source code, it seems like the main advantage of save_for_backward is that the saving is done in C rather than in Python. So it seems like any time the computation of the intermediate result would allocate Python storage anyway, it makes sense to use the ctx object.
We experienced a memory leak in a custom module using custom-made GPU kernels. Storing intermediate results in ctx.x, ctx.z etc. resulted in them not being released over multiple mini-batches, quickly exhausting GPU memory. We did not experience this in PyTorch 0.3.1; it sounds like it is due to some changes in PyTorch 0.4. Either explicitly deleting ctx.x etc. or adding x, z to .save_for_backward fixed the memory leak. I would still like to get a good explanation for why you would want to use save_for_backward vs. directly assigning to ctx.
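The explicit-del workaround could look like the following sketch, assuming the modern static-method Function API (the class name is made up; note that deleting the attribute means backward can only run once for this node):

```python
import torch
from torch.autograd import Function

class SquareOfSum(Function):
    @staticmethod
    def forward(ctx, a, b):
        c = a + b
        ctx.c = c  # intermediate stored directly on ctx
        return c * c

    @staticmethod
    def backward(ctx, grad_output):
        c = ctx.c
        del ctx.c  # release the reference as soon as it is consumed
        grad = grad_output * (2 * c)
        return grad, grad
```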
I seem to recall that in earlier torch versions it was not possible to pass an intermediate result to save_for_backward, only inputs to the forward function, but that no longer seems to be the case.
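Indeed, in recent PyTorch versions saving the intermediate itself with save_for_backward appears to work, which gives a third variant of the original example (a sketch; the class name is made up):

```python
import torch
from torch.autograd import Function

class SquareOfSumSaved(Function):
    @staticmethod
    def forward(ctx, a, b):
        c = a + b                 # intermediate result, neither input nor output
        ctx.save_for_backward(c)  # saving non-input tensors is allowed in recent versions
        return c * c

    @staticmethod
    def backward(ctx, grad_output):
        (c,) = ctx.saved_tensors
        grad = grad_output * (2 * c)
        return grad, grad
```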