Additional information for backward?

I am writing a resampled convolution layer that requires a large sampling grid. The same grid can be used in all of my convolutions, so I want to keep just one copy in memory and pass a reference to it to the convolution in the forward pass. That works fine. The memory issue arises because the entire sampling grid is saved for the backward computation. It ends up making a copy of my grid for every forward pass, which drains my memory very quickly. Because the sampling grid is identical in forward and backward, I want to write a custom backward function that takes the grid as an argument, so that only that one copy has to stay in memory. Is it possible to pass additional information into backward()?

I don’t want the grid to be a class member of the module because I don’t want each copy of the module to store the grid. Instead, I pass it in as an extra argument to forward().
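A minimal sketch of that setup, assuming a plain matrix product as a stand-in for the actual resampling (the class name `SharedGridResample` and its parameters are hypothetical). The module owns only its learnable weights; the single grid lives outside it and is handed to every instance's `forward`.

```python
import torch
from torch import nn


class SharedGridResample(nn.Module):
    """Hypothetical layer that resamples its input with a caller-supplied grid.

    The grid is deliberately not stored on the module (no attribute, no
    buffer), so any number of instances can share one tensor owned by the caller.
    """

    def __init__(self, n_samples: int, out_features: int):
        super().__init__()
        # Only the learnable weights belong to the module.
        self.weight = nn.Parameter(0.1 * torch.randn(n_samples, out_features))

    def forward(self, x: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
        # x: (batch, n_in), grid: (n_in, n_samples).
        # A plain matmul stands in for the real resampling step (assumption).
        resampled = x @ grid
        return resampled @ self.weight


# One grid in memory, created once and never requiring gradients.
grid = torch.randn(16, 8)
layers = [SharedGridResample(n_samples=8, out_features=4) for _ in range(3)]

x = torch.randn(2, 16)
outs = [layer(x, grid) for layer in layers]
```

Each layer's `state_dict()` then contains only `weight`, so saving or copying the modules does not duplicate the grid.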


If you pass it to forward as a tensor that does not require gradients, save it with ctx.save_for_backward, and get it back with ctx.saved_tensors in backward, then autograd should just keep a reference to it rather than duplicate it.
Could you share the code of a minimal Function that would reproduce this behavior?
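A minimal sketch of the approach suggested above, assuming a plain matrix product `x @ grid` as a stand-in for the resampled convolution (the name `ResampleWithSharedGrid` is hypothetical). The grid is an ordinary input to `forward`, is stored with `ctx.save_for_backward`, and is read back from `ctx.saved_tensors` in `backward`.

```python
import torch


class ResampleWithSharedGrid(torch.autograd.Function):
    """Hypothetical stand-in for a resampled convolution: out = x @ grid.

    `grid` is a fixed tensor shared by every call and never needs a gradient.
    """

    @staticmethod
    def forward(ctx, x, grid):
        # save_for_backward keeps a reference to the grid, not a copy of its data.
        ctx.save_for_backward(grid)
        return x @ grid

    @staticmethod
    def backward(ctx, grad_out):
        (grid,) = ctx.saved_tensors
        grad_x = None
        if ctx.needs_input_grad[0]:
            # Gradient of x @ grid with respect to x, applied to grad_out.
            grad_x = grad_out @ grid.t()
        # One return value per forward input; the constant grid gets None.
        return grad_x, None


grid = torch.randn(8, 4)  # the single shared grid, requires_grad=False
x = torch.randn(3, 8, requires_grad=True)
out = ResampleWithSharedGrid.apply(x, grid)
out.sum().backward()
```

Because the grid does not require gradients, `backward` only has to return `None` in its slot; every call that receives the same `grid` tensor saves a reference to the same storage.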

Thank you for clarifying. It turns out that you are right: my memory issue is due to the gradient computation itself, not to the grid being copied.
