Eager Autograd on Leaves

I am working on a huge model where I am very much memory constrained. I continuously generate regularisation terms, so my loss looks something like:

loss = accuracy + sum_{penalty} alpha * penalty

My computational graph looks like this:

computation -> computation -> computation -> computation -> computation -> (...)
                    |                          |                      
                    -> penalty                 -> penalty                                                       (...)

Since any further computation is independent of the penalty, I could calculate the gradients for the penalty eagerly. So let’s say my function is called calc_penalty(tensor). What I would like is to immediately calculate the gradient wrt the input, so that every tensor used in the computation of calc_penalty can be freed (even the input tensor).

Hi Leander!

Let me tell you what I think. (I haven’t tested or experimented with any
of this, but I think it will work.)

First a comment on autograd terminology: The “leaves” of the computation
graph are the tensors that have requires_grad = True, aren’t derived from
anything else (at least not from anything that has requires_grad = True),
and are the tensors with respect to which you wish to compute the gradient.
The “root” (or possibly roots) of the computation graph is the final loss value
whose gradient you wish to compute and on which you call .backward().

Using this terminology, I think you might have meant “Roots,” rather than
“Leaves,” in your title.
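
To make the terminology concrete, here is a minimal sketch (the tensor names w, x, y, and loss are made up for illustration). It only shows which tensors autograd reports as leaves and where the gradient ends up after calling .backward() on the root:

```python
import torch

w = torch.randn(3, requires_grad=True)  # leaf: created directly, requires grad
x = torch.randn(3)                      # plain input data, does not require grad
y = w * x                               # intermediate result derived from w
loss = y.sum()                          # root: the scalar we call .backward() on

loss.backward()

print(w.is_leaf, y.is_leaf, loss.is_leaf)
print(w.grad)  # gradient of the root with respect to the leaf w
```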

If I understand correctly, I believe that you can do this. I think that the following
pseudo-code sketch should work (but to repeat, I haven’t tried this):

# comp1
comp1 = f1 (stuff)

# penalty1, with some explicit intermediate results
p1a = g1a (comp1)
p1b = g1b (p1a)
penalty1 = g1c (p1b)

# accumulate penalty1 grad
penalty1.backward (retain_graph = True)   # preserve graph through f1 (as well as g1c, g1b, and g1a)
penalty1 = None   # release penalty1 (and then release intermediate results)
p1b = None        # set references holding on to penalty1-specific graph to None
p1a = None        # should release section of graph from penalty1 back through g1c, g1b, and g1a

# comp2
comp2 = f2 (comp1)

# penalty2, with some explicit intermediate results
p2a = g2a (comp2)
p2b = g2b (p2a)
penalty2 = g2c (p2b)

# accumulate penalty2 grad
penalty2.backward (retain_graph = True)   # preserve graph back through f2 and f1 (as well as g2c, g2b, and g2a)
penalty2 = None   # release penalty2 (and then release intermediate results)
p2b = None        # set references holding on to penalty2-specific graph to None
p2a = None        # should release section of graph from penalty2 back through g2c, g2b, and g2a

# comp3 and final loss
comp3 = f3 (comp2)
loss = loss_fn (comp3)

# accumulate final loss grad (and release remaining graph)
loss.backward()   # backpropagates through loss_fn, f3, f2, and f1 (and further back through stuff)

The idea is two-fold: First, you call .backward (retain_graph = True)
“eagerly” on your penalty terms, using retain_graph = True to preserve
the computation graph through the “backbone” of the graph. Second, as I
understand it, the references in the computation graph that keep tensors in
memory start in the downstream roots (e.g., final loss value and your penalties)
and refer to nodes further upstream towards the leaves (e.g., the parameters
you are training). Therefore, if you set certain of your own root references to
None, everything in the section of the graph between that deleted root
reference and the still-referred-to backbone can have its memory freed.
Furthermore, that “penalty-section” of the graph will be available to have
its memory freed even though you specified retain_graph = True when
you called penalty.backward().

Anyway, that’s my theory …

Yes, I believe that this will work, with the proviso that if your “input-tensor”
is needed to perform some part of the backbone backward pass, the
backbone will automatically maintain a reference to it, so it won’t be freed
(until your final loss.backward() is called).
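
Here is a small runnable sketch of the scheme above, so the claim can be checked on a toy problem. Everything in it is an assumption made for illustration: f1 and f2 are tiny Linear layers, calc_penalty is a simple squared-mean penalty, alpha is an arbitrary weight, and the final loss is just a sum. It compares the gradients accumulated by the eager penalty .backward() calls with those from a single .backward() on the summed loss. It does not measure memory.

```python
import torch

torch.manual_seed(0)
f1 = torch.nn.Linear(4, 4)   # hypothetical backbone stage 1
f2 = torch.nn.Linear(4, 4)   # hypothetical backbone stage 2
x = torch.randn(8, 4)
alpha = 0.1                  # hypothetical penalty weight
params = list(f1.parameters()) + list(f2.parameters())


def calc_penalty(t):
    # stand-in penalty; any differentiable scalar function would do
    return alpha * t.pow(2).mean()


def zero_grads():
    for p in params:
        p.grad = None


def eager_grads():
    zero_grads()
    comp1 = torch.tanh(f1(x))
    penalty1 = calc_penalty(comp1)
    penalty1.backward(retain_graph=True)  # keep the backbone graph alive
    del penalty1                          # drop our reference to the penalty branch

    comp2 = torch.tanh(f2(comp1))
    penalty2 = calc_penalty(comp2)
    penalty2.backward(retain_graph=True)
    del penalty2

    loss = comp2.sum()
    loss.backward()                       # final pass, backbone graph released
    return [p.grad.clone() for p in params]


def reference_grads():
    zero_grads()
    comp1 = torch.tanh(f1(x))
    comp2 = torch.tanh(f2(comp1))
    total = comp2.sum() + calc_penalty(comp1) + calc_penalty(comp2)
    total.backward()
    return [p.grad.clone() for p in params]


eager = eager_grads()
reference = reference_grads()
print(all(torch.allclose(a, b, atol=1e-5) for a, b in zip(eager, reference)))
```

Because .backward() accumulates into .grad, the three separate calls should add up to the same parameter gradients as the single combined call, up to floating-point rounding.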

Best.

K. Frank


I am sorry for taking so long. I didn’t get around to implementing it until now because other things were more complicated than I thought.

I need a bit more control over the gradients. As I understood it, backward computes the whole tree. How can I compute it only as far back as the input of the penalty (comp1 and comp2)? I need to manipulate the gradient again after computing it, because there’s a per-sample scaling that I now compute at the end of the network.

Hi Leander!

I don’t understand your use case, but let me make a couple of comments:

If you need to backpropagate the gradient of penalty1 back through only
the computation of penalty1 (and not through the computation of comp1)
then simply call .clone().detach() on comp1, set requires_grad = True on
the detach()ed copy, compute penalty1 from that copy, and call
penalty1.backward().
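
A minimal sketch of that detach step, with made-up stand-ins (w plays the role of a model parameter, and multiplying by 2 stands in for the computation of comp1):

```python
import torch

w = torch.randn(4, requires_grad=True)  # hypothetical model parameter
comp1 = w * 2.0                         # stand-in for the backbone computation

# New leaf holding the same values as comp1, but cut off from w's graph
comp1_cut = comp1.clone().detach().requires_grad_(True)

penalty1 = comp1_cut.pow(2).sum()       # stand-in penalty
penalty1.backward()                     # stops at comp1_cut

print(comp1_cut.grad)  # d(penalty1)/d(comp1)
print(w.grad)          # None: nothing flowed back through comp1
```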

But I imagine that you need to backpropagate penalty1’s gradient,
perhaps in modified form, back through the computation of comp1 as
well, presumably in order to update model parameters that are used
to compute comp1.

In this case, you could .clone().detach(), backpropagate just back to
the end of comp1, and save the resulting gradient, at a cost in memory,
for modification and further processing at some later time.
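
As a sketch of this two-stage variant: the saved gradient can be modified later and then injected into the backbone by passing it as the gradient argument of comp1.backward(). The names and the per-sample scale values below are invented for illustration.

```python
import torch

w = torch.randn(5, 3, requires_grad=True)  # hypothetical parameter
x = torch.randn(5, 3)                      # 5 samples
comp1 = w * x                              # stand-in for the backbone computation

# Stage 1: backpropagate the penalty only as far as (a detached copy of) comp1
comp1_cut = comp1.clone().detach().requires_grad_(True)
penalty1 = comp1_cut.pow(2).sum()
penalty1.backward()
saved_grad = comp1_cut.grad                # extra memory: same size as comp1

# Stage 2 (later): modify the saved gradient, then continue through comp1
sample_scale = torch.tensor([1.0, 0.5, 0.0, 2.0, 1.0]).unsqueeze(1)  # made-up values
comp1.backward(gradient=saved_grad * sample_scale)

print(w.grad)
```

Note that the graph of comp1 has to stay alive until stage 2 runs, in addition to the saved gradient itself.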

But it seems to me that you would be better off performing the full
backpropagation of penalty1 all at once, modifying its gradient “on
the fly.”

To do this, you can write a custom autograd function whose forward()
method just passes comp1 through unchanged, but whose backward()
method modifies the gradient of penalty1 according to your desire before
passing it on back upstream through the computation of comp1.
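
A minimal sketch of such a function, under the assumption that the modification is a per-sample scaling and that the scale tensor is already known when the penalty is computed. ScaleGrad and all tensor names are hypothetical.

```python
import torch


class ScaleGrad(torch.autograd.Function):
    """Identity in forward(); multiplies the incoming gradient by scale in backward()."""

    @staticmethod
    def forward(ctx, tensor, scale):
        ctx.save_for_backward(scale)
        return tensor.view_as(tensor)      # values pass through unchanged

    @staticmethod
    def backward(ctx, grad_output):
        (scale,) = ctx.saved_tensors
        # one return value per forward() input; scale itself gets no gradient
        return grad_output * scale, None


w = torch.randn(4, 3, requires_grad=True)  # hypothetical parameter
comp1 = w * 3.0                            # stand-in for the backbone computation
scale = torch.tensor([[1.0], [0.5], [0.0], [2.0]])  # made-up per-sample scaling

# Only the penalty branch goes through ScaleGrad; the backbone keeps using comp1 itself
penalty1 = ScaleGrad.apply(comp1, scale).pow(2).sum()
penalty1.backward()

print(w.grad)
```

Since the backbone continues from comp1 rather than from the output of ScaleGrad, only the penalty's gradient is rescaled.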

Best.

K. Frank

I have not implemented it yet, since I am still struggling with other parts, but to me this seems to be the correct idea! It might be hard to understand my use case. It’s quite special, since the tensors my network operates on are in the gigabytes, so very, very large. I currently use gradient checkpointing heavily, but that’s just super slow.

Update: I have implemented it and it works! If somebody wants to do something similar, I would advise looking at how checkpoints are implemented.

As I understood it, backward computes the whole tree. How can I only compute it until the input of the penalty (comp1 and comp2 ?)?

FYI you can specify the inputs you want to backward to with .backward(inputs=) and it will only backprop through the parts of your graph that are needed to compute gradients for those specified inputs.
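
A small sketch of the inputs= argument, using two made-up leaf tensors (w_backbone and w_pen are hypothetical names). Only the tensor listed in inputs gets its .grad populated:

```python
import torch

w_backbone = torch.randn(3, requires_grad=True)  # hypothetical backbone parameter
w_pen = torch.randn(3, requires_grad=True)       # hypothetical penalty parameter

comp1 = w_backbone * 2.0                         # stand-in for the backbone computation
penalty = (comp1 * w_pen).sum()                  # stand-in penalty

# Accumulate gradients only for w_pen; w_backbone is left alone
penalty.backward(inputs=[w_pen])

print(w_pen.grad)       # equals comp1
print(w_backbone.grad)  # None
```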