I am working on a huge model where I am severely memory-constrained. I continuously generate regularisation terms; my loss looks something like:

`loss = accuracy + sum_over_penalties(alpha * penalty)`

my computational graph looks like this:

```
computation -> computation -> computation -> computation -> computation -> (...)
     |                             |
     -> penalty                    -> penalty                                (...)
```

Since any further computation is independent of the penalty, I could calculate the gradients for the penalty eagerly. So let's say my function is called `calc_penalty(tensor)`; what I would like is to immediately calculate the gradient with respect to the input, so that every tensor used in the computation of `calc_penalty` can be freed (even the input tensor).

Hi Leander!

Let me tell you what I think. (I haven't tested or experimented with any
of this, but I think it will work.)

First a comment on autograd terminology: The "leaves" of the computation
graph are the tensors that have `requires_grad = True`, aren't derived from
anything else (at least not from anything that has `requires_grad = True`),
and are the tensors with respect to which you wish to compute the gradient.
The "root" (or possibly roots) of the computation graph is the final loss value
whose gradient you wish to compute and on which you call `.backward()`.

Using this terminology, I think you might have meant "roots," rather than "leaves."

If I understand correctly, I believe that you can do this. I think that the following
pseudo-code sketch should work (but to repeat, I haven't tried this):

```python
# comp1
comp1 = f1(stuff)

# penalty1, with some explicit intermediate results
p1a = g1a(comp1)
p1b = g1b(p1a)
penalty1 = g1c(p1b)

penalty1.backward(retain_graph=True)  # preserve graph through f1 (as well as g1a, g1b, and g1c)
penalty1 = None  # release penalty1 (and then release intermediate results)
p1b = None       # set references holding on to penalty1-specific graph to None
p1a = None       # should release section of graph from penalty1 back through g1c, g1b, and g1a

# comp2
comp2 = f2(comp1)

# penalty2, with some explicit intermediate results
p2a = g2a(comp2)
p2b = g2b(p2a)
penalty2 = g2c(p2b)

penalty2.backward(retain_graph=True)  # preserve graph back through f2 and f1 (as well as g2a, g2b, and g2c)
penalty2 = None  # release penalty2 (and then release intermediate results)
p2b = None       # set references holding on to penalty2-specific graph to None
p2a = None       # should release section of graph from penalty2 back through g2c, g2b, and g2a

# comp3 and final loss
comp3 = f3(comp2)
loss = loss_fn(comp3)

# accumulate final loss grad (and release remaining graph)
loss.backward()  # backpropagates through loss_fn, f3, f2, and f1 (and further back through stuff)
```

The idea is two-fold: First, you call `.backward(retain_graph=True)`
"eagerly" on your penalty terms, using `retain_graph=True` to preserve
the computation graph through the "backbone" of the graph. Second, as I
understand it, the references in the computation graph that keep tensors in
memory start in the downstream roots (e.g., final loss value and your penalties)
and refer to nodes further upstream towards the leaves (e.g., the parameters
you are training). Therefore, if you set certain of your own root references to
`None`, everything in the section of the graph between that deleted root
reference and the still-referred-to backbone can have its memory freed.
Furthermore, that "penalty section" of the graph will be available to have
its memory freed even though you specified `retain_graph=True` when
you called `penalty.backward()`.
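As a sanity check on this theory, here is a tiny runnable sketch; the small tensor computations are hypothetical stand-ins for the real model. It verifies that eager per-penalty backward passes accumulate the same gradient as one combined backward on the summed loss:

```python
import torch

w = torch.randn(5, requires_grad=True)  # stand-in "leaf" parameter

# eager version: backward each penalty as soon as it is computed
comp1 = w * 2
penalty1 = comp1.pow(2).sum()
penalty1.backward(retain_graph=True)  # keep the backbone graph through comp1
comp2 = comp1 + 1
penalty2 = comp2.pow(2).sum()
penalty2.backward(retain_graph=True)
loss = comp2.sum()
loss.backward()                       # final pass releases the rest of the graph
grad_eager = w.grad.clone()

# reference version: one backward on the combined loss
w.grad = None
comp1 = w * 2
comp2 = comp1 + 1
(comp1.pow(2).sum() + comp2.pow(2).sum() + comp2.sum()).backward()

print(torch.allclose(grad_eager, w.grad))  # True: the two gradients agree
```

Gradients accumulate into `w.grad` across the eager `.backward()` calls, which is exactly why the two versions end up equal.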

Anyway, that's my theory …

Yes, I believe that this will work, with the proviso that if your "input tensor"
is needed to perform some part of the backbone backward pass, the
backbone will automatically maintain a reference to it, so it won't be freed
(until your final `loss.backward()` is called).

Best.

K. Frank


I am sorry for taking so long; I didn't get around to implementing it until now because other things were more complicated than I thought.

I need a bit more control over the gradients. As I understand it, `backward` computes gradients through the whole graph. How can I compute them only up to the input of the penalty (`comp1` and `comp2`)? I need to manipulate the gradient after computing it; there's a per-sample scaling that I now compute at the end of the network.

Hi Leander!

I don't understand your use case, but let me make a couple of comments:

If you need to backpropagate the gradient of `penalty1` back through only
the computation of `penalty1` (and not through the computation of `comp1`),
then simply `.clone().detach()` `comp1`, set `requires_grad = True` on
the `detach()`ed copy, compute `penalty1` from it, and call `penalty1.backward()`.
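A minimal sketch of that first option (the small computation and the name `comp1_detached` are my own illustrative stand-ins):

```python
import torch

w = torch.randn(3, requires_grad=True)
comp1 = w * 3                  # stand-in for the upstream computation

# cut the graph at comp1
comp1_detached = comp1.clone().detach().requires_grad_(True)
penalty1 = comp1_detached.pow(2).sum()
penalty1.backward()            # backward stops at comp1_detached

print(comp1_detached.grad)     # gradient of the penalty wrt its input
print(w.grad)                  # None: nothing propagated back to w
```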

But I imagine that you need to backpropagate `penalty1`'s gradient,
perhaps in modified form, back through the computation of `comp1` as
well, presumably in order to update model parameters that are used
to compute `comp1`.

In this case, you could `.clone().detach()`, backpropagate just back to
the end of `comp1`, and save the resulting gradient, at a cost in memory,
for modification and further processing at some later time.
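A sketch of this deferred approach, again with hypothetical stand-in computations: hold on to the gradient at the end of `comp1`, modify it later, and then resume backpropagation by passing it to `comp1.backward()` as the incoming gradient. The factor `0.5` is an illustrative stand-in for whatever modification you need:

```python
import torch

w = torch.randn(3, requires_grad=True)
comp1 = w * 3                  # stand-in for the upstream computation

comp1_in = comp1.clone().detach().requires_grad_(True)
penalty1 = comp1_in.pow(2).sum()
penalty1.backward()            # backpropagate only to the end of comp1

saved_grad = comp1_in.grad     # held in memory until needed
# ... later: modify the saved gradient, then resume backpropagation
comp1.backward(saved_grad * 0.5)  # illustrative scaling of 0.5
print(w.grad)                  # now populated, via the modified gradient
```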

But it seems to me that you would be better off performing the full
backpropagation of `penalty1` all at once, modifying its gradient "on
the fly."

To do this, you can write a custom autograd function whose `forward()`
method just passes `comp1` through unchanged, but whose `backward()`
method modifies the gradient of `penalty1` according to your desire before
passing it on back upstream through the computation of `comp1`.
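A sketch of such a pass-through function. The name `ScaleGrad` and the fixed factor `0.1` are illustrative assumptions; the real `backward()` would apply your per-sample scaling instead:

```python
import torch

class ScaleGrad(torch.autograd.Function):
    """Identity in forward; scales the incoming gradient in backward."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x.view_as(x)    # pass x through unchanged

    @staticmethod
    def backward(ctx, grad_output):
        # modify the penalty's gradient before it flows upstream
        return grad_output * ctx.scale, None

w = torch.randn(4, requires_grad=True)
comp1 = w * 2
penalty1 = ScaleGrad.apply(comp1, 0.1).pow(2).sum()
penalty1.backward()            # w.grad receives the scaled gradient
```

Because the wrapper sits only on the penalty branch, gradients flowing through `comp1` from the rest of the network are unaffected.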

Best.

K. Frank

I have not implemented it yet, since I am still struggling with other parts, but to me this seems to be the correct idea! It might be hard to understand my use case; it's quite special, since the tensors my network operates on are gigabytes in size. I currently use gradient checkpointing heavily, but that's just super slow.

Update: I have implemented it and it works! If somebody wants to do something similar, I would advise looking at how checkpoints are implemented.

> As I understood it, `backward` computes the whole tree. How can I only compute it until the input of the penalty (`comp1` and `comp2`)?

FYI you can specify the inputs you want to backprop to with `.backward(inputs=...)`, and it will only backprop through the parts of your graph that are needed to compute gradients for those specified inputs.
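For example, with tiny illustrative tensors:

```python
import torch

w1 = torch.randn(3, requires_grad=True)
w2 = torch.randn(3, requires_grad=True)

out = (w1 * w2).sum()
out.backward(inputs=[w1])   # only backprop what w1's gradient needs

print(w1.grad)              # populated
print(w2.grad)              # None: w2 was not listed in `inputs`
```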