I am writing a resampled convolution layer that requires a large sampling grid. The same grid can be used in all of my convolutions, so I want to keep a single copy in memory and pass a reference to it into each convolution during the forward pass. That part works fine. The memory problem arises because the entire sampling grid is saved for the backward computation: a copy of the grid appears to be made on every forward pass, which drains my memory very quickly. Because the sampling grid is identical for forward and backward, I want to write a custom backward function that takes the grid as an argument, so that I only need to keep that one copy in memory. Is it possible to pass additional information into backward()?
I don’t want the grid to be a class member of the module, because I don’t want each copy of the module to store the grid. Instead, I pass it in as an extra argument to forward().
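To make the setup concrete, here is a minimal sketch of the pattern described above, assuming PyTorch's torch.autograd.Function. The class name GridResample and the index-based "resampling" are hypothetical stand-ins for the real layer, not the actual convolution. The grid is passed to forward() as an extra argument and handed to ctx.save_for_backward(), which, as far as I understand, keeps a reference to the tensor rather than a copy, so backward() can reuse the same shared grid.

```python
import torch


class GridResample(torch.autograd.Function):
    """Hypothetical stand-in: resamples the last dimension of x at integer grid indices."""

    @staticmethod
    def forward(ctx, x, grid):
        # Hand the shared grid to autograd so backward() can retrieve it.
        ctx.save_for_backward(grid)
        # Only the input shape is needed for the gradient, not the input values.
        ctx.in_shape = x.shape
        ctx.dim = x.dim() - 1
        return x.index_select(ctx.dim, grid)

    @staticmethod
    def backward(ctx, grad_out):
        (grid,) = ctx.saved_tensors
        # Scatter-add each output gradient back to the input position it was read from.
        grad_x = grad_out.new_zeros(ctx.in_shape)
        grad_x.index_add_(ctx.dim, grid, grad_out)
        # One return value per forward() argument; the grid gets no gradient.
        return grad_x, None


# One grid, created once and shared by every call.
grid = torch.tensor([0, 2, 2, 3])
x = torch.arange(8.0).reshape(2, 4).requires_grad_()

y = GridResample.apply(x, grid)
y.sum().backward()
```

In this toy case, index 2 is sampled twice and index 1 never, so x.grad comes out as [[1, 0, 2, 1], [1, 0, 2, 1]]. Whether this removes the extra memory use in the real layer depends on where the copies actually come from, for example an intermediate that expands the grid across the batch.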