Eager Autograd on Leaves

I am working on a huge model and I am severely memory constrained. I continuously generate regularisation terms, so my loss looks something like:

loss = accuracy + sum_penalty (alpha * penalty)

my computational graph looks like this:

computation -> computation -> computation -> computation -> computation -> (...)
                    |                          |
                    -> penalty                 -> penalty                   (...)

Since any further computation is independent of the penalty, I could calculate the gradients for the penalty eagerly. So, let’s say my function is called calc_penalty(tensor); what I would like is to immediately calculate the gradient with respect to the input, so that every tensor used in the computation of calc_penalty can be freed (even the input tensor itself).
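
Schematically, the sort of thing I have in mind (just an illustrative sketch;
computation and calc_penalty are toy stand-ins for my real code):

import torch

prev = torch.randn(10, requires_grad=True)
computation = lambda t: 2.0 * t               # stand-in for one backbone step
calc_penalty = lambda t: t.pow(2).sum()       # stand-in for my real regulariser

x = computation(prev)                         # backbone step
penalty = calc_penalty(x)                     # side branch off the backbone
grad_x, = torch.autograd.grad(penalty, x)     # d(penalty)/dx computed right away; the graph
                                              # that calc_penalty built is freed by this call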

Hi Leander!

Let me tell you what I think. (I haven’t tested or experimented with any
of this, but I think it will work.)

First a comment on autograd terminology: The “leaves” of the computation
graph are the tensors that have requires_grad = True, aren’t derived from
anything else (at least not from anything that has requires_grad = True),
and are the tensors with respect to which you wish to compute the gradient.
The “root” (or possibly roots) of the computation graph is the final loss value
whose gradient you wish to compute and on which you call .backward().

Using this terminology, I think you might have meant “Roots,” rather than
“Leaves,” in your title.
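
For concreteness, here is a tiny (untested, but I believe correct) illustration
of that terminology:

import torch

w = torch.randn(3, requires_grad=True)   # a leaf: requires grad, not derived from anything tracked
x = torch.randn(3)                       # plain input (no requires_grad)
loss = (w * x).sum()                     # a root: the value on which you call .backward()

print(w.is_leaf, loss.is_leaf)           # True False
loss.backward()                          # gradients accumulate in the leaves
print(torch.equal(w.grad, x))            # True: d(loss) / d(w) is just x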

If I understand correctly, I believe that you can do this. I think that the following
pseudo-code sketch should work (but to repeat, I haven’t tried this):

# comp1
comp1 = f1 (stuff)

# penalty1, with some explicit intermediate results
p1a = g1a (comp1)
p1b = g1b (p1a)
penalty1 = g1c (p1b)

# accumulate penalty1 grad
penalty1.backward (retain_graph = True)   # preserve graph through f1 (as well as g1c, g1b, and g1a)
penalty1 = None   # release penalty1 (and then release intermediate results)
p1b = None        # set references holding on to penalty1-specific graph to None
p1a = None        # should release section of graph from penalty1 back through g1c, g1b, and g1a

# comp2
comp2 = f2 (comp1)

# penalty2, with some explicit intermediate results
p2a = g2a (comp2)
p2b = g2b (p2a)
penalty2 = g2c (p2b)

# accumulate penalty2 grad
penalty2.backward (retain_graph = True)   # preserve graph back through f2 and f1 (as well as g2c, g2b, and g2a)
penalty2 = None   # release penalty2 (and then release intermediate results)
p2b = None        # set references holding on to penalty2-specific graph to None
p2a = None        # should release section of graph from penalty2 back through g2c, g2b, and g2a

# comp3 and final loss
comp3 = f3 (comp2)
loss = loss_fn (comp3)

# accumulate final loss grad (and release remaining graph)
loss.backward()   # backpropagates through loss_fn, f3, f2, and f1 (and further back through stuff)
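
And for what it's worth, here is a small runnable version of the same pattern,
with tiny Linear layers standing in for your f's and g's (purely illustrative,
and, again, untested by me):

import torch
import torch.nn as nn

f1, f2, f3 = nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8)
g1, g2 = nn.Linear(8, 1), nn.Linear(8, 1)   # stand-ins for the penalty branches
alpha = 0.1
stuff = torch.randn(4, 8)

comp1 = f1(stuff)

penalty1 = alpha * g1(comp1).pow(2).mean()
penalty1.backward(retain_graph=True)        # penalty1's gradients land in the leaves now
penalty1 = None                             # drop the penalty1 section of the graph

comp2 = f2(comp1)

penalty2 = alpha * g2(comp2).pow(2).mean()
penalty2.backward(retain_graph=True)        # again preserving the backbone graph
penalty2 = None                             # drop the penalty2 section of the graph

comp3 = f3(comp2)
loss = comp3.pow(2).mean()                  # stand-in for your accuracy term
loss.backward()                             # frees the remaining (backbone) graph

print(f1.weight.grad.norm())                # accumulated contributions from both penalties and the loss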

The idea is two-fold: First, you call .backward (retain_graph = True)
“eagerly” on your penalty terms, using retain_graph = True to preserve
the computation graph through the “backbone” of the graph. Second, as I
understand it, the references in the computation graph that keep tensors in
memory start in the downstream roots (e.g., final loss value and your penalties)
and refer to nodes further upstream towards the leaves (e.g., the parameters
you are training). Therefore, if you set certain of your own root references to
None, everything in the section of the graph between that deleted root
reference and the still-referred-to backbone can have its memory freed.
Furthermore, that “penalty-section” of the graph will be available to have
its memory freed even though you specified retain_graph = True when
you called penalty.backward().
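
If you want to test this theory, one simple check (assuming your tensors live
on a CUDA device) is to watch torch.cuda.memory_allocated() before and after
you drop the penalty-branch references, along these lines:

import torch

def alloc_mb():
    return torch.cuda.memory_allocated() / 2**20    # tensor memory currently allocated (CUDA only)

comp1 = torch.randn(2048, 2048, device='cuda', requires_grad=True)   # stands in for a backbone value
p1a = comp1.sin()                                    # penalty-branch intermediate (about 16 MB)
penalty1 = p1a.pow(2).mean()

penalty1.backward(retain_graph=True)
print('before dropping branch refs:', alloc_mb())

penalty1 = None
p1a = None                                           # release the penalty-branch section of the graph
print('after dropping branch refs: ', alloc_mb())    # should drop by roughly the size of p1a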

Anyway, that’s my theory …

Yes, I believe that this will work, with the proviso that if your “input-tensor”
is needed to perform some part of the backbone backward pass, the
backbone will automatically maintain a reference to it, so it won’t be freed
(until your final loss.backward() is called).
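
As a concrete illustration of that proviso (a toy example; multiplication is
one op whose backward needs its inputs):

import torch

w = torch.randn(5, requires_grad=True)
inp = torch.randn(5, requires_grad=True)   # plays the role of your "input-tensor"
comp = w * inp                             # mul's backward needs inp to compute d(comp) / d(w)
loss = comp.sum()

del inp                                    # drop your own reference to it ...
loss.backward()                            # ... but the graph kept it saved, so this still works
print(w.grad)                              # equal to the values that inp held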

Best.

K. Frank