When should you use save_for_backward vs. storing directly on ctx?

We experienced a memory leak in a custom module that uses custom-made GPU kernels. Storing intermediate results as ctx.x, ctx.z, etc. meant they were not released across mini-batches, which quickly exhausted GPU memory. We did not see this in PyTorch 0.3.1, so it appears to be due to changes in PyTorch 0.4. Either explicitly calling del ctx.x etc. at the end of backward, or passing x and z to ctx.save_for_backward, fixed the leak. I would still like a good explanation of why you would want to use save_for_backward rather than assigning directly to ctx.
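For reference, here is a minimal sketch of the pattern that fixed it for us. The function, name, and math are just a toy stand-in for our real kernel, not the actual code:

```python
import torch

class SquarePlus(torch.autograd.Function):
    """Toy example: y = x**2 + z (illustrative only, not our real kernel)."""

    @staticmethod
    def forward(ctx, x, z):
        # Tensors needed in backward go through save_for_backward, so
        # autograd manages their lifetime and can detect in-place edits.
        ctx.save_for_backward(x, z)
        # Non-tensor state (flags, shapes, scalars) can still live on ctx.
        ctx.some_flag = True
        return x * x + z

    @staticmethod
    def backward(ctx, grad_output):
        x, z = ctx.saved_tensors
        grad_x = 2 * x * grad_output   # d/dx (x**2 + z)
        grad_z = grad_output           # d/dz (x**2 + z)
        return grad_x, grad_z


# Usage sketch
x = torch.randn(4, requires_grad=True)
z = torch.randn(4, requires_grad=True)
y = SquarePlus.apply(x, z)
y.sum().backward()
```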

I seem to recall that in earlier torch versions it was not possible to save_for_backward an intermediate result, only inputs to the forward function, but that no longer seems to be the case.
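For example, something like the following appears to work on recent versions, where an intermediate tensor computed inside forward (not an input) is passed to save_for_backward. Again, this is just a toy sketch, not our actual module:

```python
import torch

class ExpTimes(torch.autograd.Function):
    """Toy example: y = exp(x) * w, saving the intermediate exp(x)."""

    @staticmethod
    def forward(ctx, x, w):
        exp_x = torch.exp(x)             # intermediate result
        ctx.save_for_backward(exp_x, w)  # saving a non-input tensor
        return exp_x * w

    @staticmethod
    def backward(ctx, grad_output):
        exp_x, w = ctx.saved_tensors
        grad_x = grad_output * w * exp_x  # d/dx (exp(x) * w)
        grad_w = grad_output * exp_x      # d/dw (exp(x) * w)
        return grad_x, grad_w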
