Forcing PyTorch not to cache a generated static tensor

Hi,
My model has a weight tensor W, which contains duplicate entries due to symmetry.
I compacted the weight tensor into U, which is 1/4 the size of W, and generate an index tensor B (same size as W) to expand U back to W on the fly as W = U[B]. Note that B's content is static/fixed.
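A minimal sketch of the compaction described above. The actual symmetry isn't specified, so this assumes (hypothetically) that W is mirror-symmetric along both axes, i.e. W[i, j] == W[m-1-i, j] == W[i, m-1-j], which leaves only the top-left (m/2)x(m/2) quadrant unique. `make_mirror_index` is an illustrative helper, and U is indexed through its flattened view so that B can hold flat indices:

```python
import torch

def make_mirror_index(m: int) -> torch.Tensor:
    # Hypothetical symmetry: W[i, j] == W[m-1-i, j] == W[i, m-1-j],
    # so only the top-left (m/2) x (m/2) quadrant of W is unique.
    assert m % 2 == 0
    h = m // 2
    idx = torch.arange(m)
    fold = torch.minimum(idx, m - 1 - idx)  # map each row/col to its mirror in [0, h)
    return fold[:, None] * h + fold[None, :]  # flat index into U.flatten()

m = 6
U = torch.randn(m // 2, m // 2, requires_grad=True)  # compact weight
B = make_mirror_index(m)   # static index, shape (m, m), dtype torch.int64
W = U.flatten()[B]         # expand U back to the full (m, m) weight
```

Because W is produced by indexing, autograd sums the gradients of all four mirrored positions into the corresponding entry of U.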
I was hoping that, since B is generated on the fly, I could save some memory, but training in PyTorch actually uses more memory.
What is the reason behind this? If the index tensor B is cached because it can be predetermined, can I force it to always be generated on the fly?
Thank you very much!

I'm not sure I understand the use case correctly.
Based on the code, it seems you are creating the entire W tensor as well as the additional index tensor B, while also holding U in memory. I would guess you might save some memory before B and W are created, but you would use more memory once the weight is created. Where would the memory saving come from?

Hi,
Let’s say a layer has input X, and produces output Y.
In the original form, we need to store at each layer a weight matrix W of size m*m.
In the compact form, we need to store at each layer a weight matrix U of size (m/2)*(m/2) and generate the expanded version W = U[B] on the fly. We use W to compute the layer output Y, after which W and B can be discarded. Backpropagation should then only need to retain U and Y at each layer of the graph.
I'm not sure whether my understanding of autograd is correct. Is there a way to get this behavior?
Thanks!

Thanks for the description, as it fits my understanding. The peak memory would increase, since U, B, W, and Y all have to be stored at one point during the forward pass. Since Autograd needs to backpropagate through Y, the intermediates needed for the backward pass will be kept, e.g. W to compute the gradient w.r.t. the layer input and B to route the gradient back into U.
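To see which tensors autograd actually keeps, one could hook into the saved-tensor mechanism via `torch.autograd.graph.saved_tensors_hooks` (available in recent PyTorch releases). This sketch assumes the same hypothetical mirror-symmetric index and a 2D input X that requires grad, as an inner layer's input would; it records the shape, dtype, and byte size of every tensor saved for backward. Note that B, as an int64 tensor with as many elements as W, takes twice the bytes of a float32 W, which may help explain why the compact version uses more memory:

```python
import torch

m, batch = 6, 4
h = m // 2
idx = torch.arange(m)
fold = torch.minimum(idx, m - 1 - idx)         # hypothetical mirror symmetry
B = fold[:, None] * h + fold[None, :]          # static int64 index, shape (m, m)

U = torch.randn(h * h, requires_grad=True)     # compact weight (flattened)
X = torch.randn(batch, m, requires_grad=True)  # stands in for an inner layer's input

saved = []

def pack(t):
    # Called once for every tensor autograd stores for the backward pass
    saved.append((tuple(t.shape), t.dtype, t.numel() * t.element_size()))
    return t

def unpack(t):
    return t

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    W = U[B]    # indexing keeps B to scatter gradients back into U
    Y = X @ W   # the matmul keeps X (for dL/dW) and W (for dL/dX)

for shape, dtype, nbytes in saved:
    print(shape, dtype, nbytes, "bytes")
```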
If you have a way to delete W and recalculate it in the backward pass, you could try to implement a custom autograd.Function with this logic.
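A minimal sketch of such a custom `autograd.Function`, again assuming the hypothetical mirror symmetry: it saves only X and the compact U, and rebuilds both B and W = U[B] inside `backward`. `CompactLinear` and `mirror_index` are illustrative names, and U is kept flattened so that B holds flat indices:

```python
import torch

def mirror_index(m: int, device=None) -> torch.Tensor:
    # Hypothetical symmetry: W[i, j] == W[m-1-i, j] == W[i, m-1-j]
    h = m // 2
    idx = torch.arange(m, device=device)
    fold = torch.minimum(idx, m - 1 - idx)
    return fold[:, None] * h + fold[None, :]

class CompactLinear(torch.autograd.Function):
    """Y = X @ U[B], without keeping B or W = U[B] alive for backward."""

    @staticmethod
    def forward(ctx, X, U, m):
        B = mirror_index(m, U.device)  # generated on the fly, not saved
        W = U[B]                       # freed once forward returns
        ctx.save_for_backward(X, U)    # only the input and the compact weight
        ctx.m = m
        return X @ W

    @staticmethod
    def backward(ctx, grad_Y):
        X, U = ctx.saved_tensors
        B = mirror_index(ctx.m, U.device)  # recompute the static index
        grad_X = grad_U = None
        if ctx.needs_input_grad[0]:
            grad_X = grad_Y @ U[B].t()     # recompute W only for this product
        if ctx.needs_input_grad[1]:
            grad_W = X.t() @ grad_Y
            # Sum the gradients of all W positions that map to the same U entry
            grad_U = torch.zeros_like(U).index_add_(0, B.flatten(), grad_W.flatten())
        return grad_X, grad_U, None        # no gradient for the int m

m = 6
U = torch.nn.Parameter(torch.randn((m // 2) ** 2))
X = torch.randn(4, m, requires_grad=True)
Y = CompactLinear.apply(X, U, m)
Y.sum().backward()
```

The trade-off is recomputing B and W, which is cheap relative to the matmuls, in exchange for never holding them between forward and backward; `index_add_` performs the same gradient accumulation that the regular `U[B]` backward would.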

I came across torch.utils.checkpoint and am wondering whether it would solve the problem.


Edit: I can confirm that checkpointing solves the problem, at the cost of extra computation time.
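A minimal sketch of the checkpointing approach, assuming the same hypothetical mirror symmetry: B and W = U[B] are built inside the checkpointed function, so only its inputs (X and U) are kept and everything created inside is recomputed during backward. `CompactMirrorLinear`, `expand_and_apply`, and `mirror_index` are illustrative names; `use_reentrant=False` selects the non-reentrant checkpoint variant available in recent PyTorch versions:

```python
import torch
from torch.utils.checkpoint import checkpoint

def mirror_index(m: int, device=None) -> torch.Tensor:
    # Hypothetical symmetry: W[i, j] == W[m-1-i, j] == W[i, m-1-j]
    h = m // 2
    idx = torch.arange(m, device=device)
    fold = torch.minimum(idx, m - 1 - idx)
    return fold[:, None] * h + fold[None, :]

def expand_and_apply(X: torch.Tensor, U: torch.Tensor, m: int) -> torch.Tensor:
    # B and W are created here, dropped after forward, and recomputed in backward
    B = mirror_index(m, U.device)
    return X @ U[B]

class CompactMirrorLinear(torch.nn.Module):
    def __init__(self, m: int):
        super().__init__()
        self.m = m
        self.U = torch.nn.Parameter(torch.randn((m // 2) ** 2) / m)  # compact weight

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return checkpoint(expand_and_apply, X, self.U, self.m, use_reentrant=False)

layer = CompactMirrorLinear(6)
Y = layer(torch.randn(4, 6))
Y.sum().backward()  # re-runs expand_and_apply, then fills layer.U.grad
```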