Gradient checkpointing

I wonder how torch.utils.checkpoint is supposed to be used in practice. The outputs of the checkpointed function are potentially stored elsewhere for backprop anyway, since the ops that consume them are not under the checkpoint anymore. E.g. consider the example:

y = (a + b) * c

And now checkpoint around the a + b:

a = ...
b = ...
c = ...
x = checkpoint(lambda _a, _b: _a + _b, a, b)
y = x * c

My intention is that the result of a + b is not stored in memory but recomputed. However, in this example, it doesn’t really make sense to recompute x = a + b, because the result is stored anyway for the backprop of x * c. Or is the intended usage of checkpoint for this example actually like:

a = ...
b = ...
c = ...
y = checkpoint(lambda _a, _b, _c: (_a + _b) * _c, a, b, c)

?

Another question: Are the explicit args to checkpoint actually necessary? Or can I just use this instead:

a = ...
b = ...
c = ...
y = checkpoint(lambda: (a + b) * c)

?
In the real world, model parameters, for example, are probably not passed as explicit inputs but are used from within the function.

This is correct. The idea is that you’d use checkpointing on large enough regions such that the activations saved at the boundaries are a relatively small proportion of the overall activations.
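For example, a minimal sketch with a hypothetical toy model, checkpointing whole blocks so that only the boundary activations are kept:

import torch
from torch.utils.checkpoint import checkpoint

# Toy stack of blocks: only the input at each block boundary is saved;
# the Linear/ReLU intermediates inside each block are recomputed in backward.
blocks = torch.nn.ModuleList(
    torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU())
    for _ in range(8)
)

x = torch.randn(32, 512, requires_grad=True)
for block in blocks:
    x = checkpoint(block, x, use_reentrant=False)
x.sum().backward()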

Are the explicit args to checkpoint actually necessary? Or can I just use this instead:

Yes, that is supported. There’s a subtle difference in when closed-over buffers are cleared, but that shouldn’t matter for basic usage, i.e., forward then backward once.


Ok, so for my example, if I want a + b to be recomputed rather than stored, I would need to use it like this:

a = ...
b = ...
c = ...
y = checkpoint(lambda: (a + b) * c)

In this example, is the multiplication ... * c also recomputed then, or not? Because I actually do not want to recompute that. In practice, this multiplication is likely a big costly matmul. I know exactly that I only want to recompute the a + b and nothing else. I fear that the checkpoint API might do more recomputation than I actually want, and/or not the recomputation that I do want.

As I understand the code and documentation, use_reentrant=False is already more like what I want. For my simple example, I think it would not recompute the multiplication ... * c, right? But how would it actually do that? Looking at the code, I see that it calls recompute_fn once it needs some of the activations in the backward pass. But recompute_fn will just call the original function fn at some point, which includes the computation of the multiplication? Edit: Ah, now I see the logic with _StopRecomputationError. This all looks quite hacky to me…

Yes, with use_reentrant=False it would not be recomputed, thanks to _StopRecomputationError. use_reentrant=True does not have this logic, so in that path the entire forward is always recomputed.
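For example, a minimal sketch of the closure variant with use_reentrant=False (assuming the early-stop behavior described above):

import torch
from torch.utils.checkpoint import checkpoint

a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
c = torch.randn(4, requires_grad=True)

# With use_reentrant=False, recomputation stops early (internally via
# _StopRecomputationError) once all tensors saved for backward have been
# repacked, so the multiplication itself is not executed again.
y = checkpoint(lambda: (a + b) * c, use_reentrant=False)
y.sum().backward()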

So I think my original questions are answered then. But I wonder: in my case, and in this example, I know that I want to recompute x = a + b, but nothing else, whatever other computations come afterwards (which I maybe don’t even have control over; some other unrelated code). The current checkpoint API does not really allow me to specify that. Any code that comes after it and uses this x directly must also be put into the checkpointed function.

Is it unusual that I only want to recompute a + b and nothing else? Basically, with the current checkpoint API, I specify the points (tensors) which are stored in memory. However, I want an API where I specify the tensors which are not stored.

That sounds like a reasonable thing to do. However, unfortunately there’s no automatic way to apply checkpointing to the consumers of this output.

The closest API there is to this today is selective activation checkpoint: torch.utils.checkpoint — PyTorch main documentation (landed very recently; available in nightlies or if you build from source).

With that you can pass in a policy of “if my op, recompute; otherwise, save”. However, it will save the output of ALL ops in the region except the op you specified, not just the buffers that are needed for the backward computation.
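A hedged sketch of what that could look like with the selective activation checkpoint API (create_selective_checkpoint_contexts / CheckpointPolicy; since this landed only recently, the exact surface may differ across versions):

from functools import partial

import torch
from torch.utils.checkpoint import (
    checkpoint,
    create_selective_checkpoint_contexts,
    CheckpointPolicy,
)

def policy_fn(ctx, op, *args, **kwargs):
    # Recompute the cheap add; save the outputs of everything else.
    if op is torch.ops.aten.add.Tensor:
        return CheckpointPolicy.MUST_RECOMPUTE
    return CheckpointPolicy.MUST_SAVE

context_fn = partial(create_selective_checkpoint_contexts, policy_fn)

a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
c = torch.randn(4, requires_grad=True)

y = checkpoint(lambda: (a + b) * c, use_reentrant=False, context_fn=context_fn)
y.sum().backward()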

I wonder: if I use torch.autograd.graph.saved_tensors_hooks around my whole model, then I can decide exactly what I want to store and what to recompute. How much of a performance penalty would I get by using saved_tensors_hooks all the time?
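For reference, the basic shape of that mechanism, as a no-op identity sketch (a real version would pack some recipe for recomputation instead of the tensor itself):

import torch

def pack(t):
    # Decide here: store t as-is, or return enough info to rebuild it later.
    return t

def unpack(packed):
    # Invert pack: return (or recompute) the tensor needed for backward.
    return packed

a = torch.randn(4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = (a * a).sum()
y.backward()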

I also had the idea to delay saved_tensors_hooks.__exit__ to some later point in checkpoint, specifically to register some Tensor.__del__ hook (e.g. like this) on the outputs of fn, and once this hook gets called, to then do the saved_tensors_hooks.__exit__. But there are many further details to make this work: e.g. I need to check new_frame.forward_completed, and any newly created tensors should be stored, just not those created so far. This is probably not really possible in a clean way.

Also, in saved_tensors_hooks, I somehow need to check whether a tensor depends on some of the computation from fn or not; I’m not sure how to do this.

How much performance penalty would I get by using saved_tensors_hooks all the time?

Good question. You do always incur an additional Python round trip. But whether this matters really depends on whether your model is overhead-bound.

I wonder, if I use torch.autograd.graph.saved_tensors_hooks around my whole model

This is a good idea, and I have explored something similar in the past. The idea is to record, via a RecomputableTensor object, the dependency relations for exactly what needs to be computed during backward, in order to recompute only precisely what is necessary.

The idea is that I define this new RecomputableTensor, which doesn’t actually hold data but knows how to compute itself from other RecomputableTensors. When I save for backward, I use saved tensor hooks to save a RecomputableTensor instead of a plain tensor. So let’s say that during backward I need a tensor saved for backward for the first gradient computation: it would recursively call recompute until it reaches a RecomputableTensor that has already been materialized. The way selective checkpoint would work is that we would make sure RecomputableTensors generated from those specific ops are already materialized, so that the recursion stops early.

It should also be able to save precisely what is needed. The RecomputableTensors are kept alive in 3 ways: (1) by the original tensors, (2) by other RecomputableTensors, (3) by being saved onto the graph. If a saved tensor is not kept alive by a chain of RecomputableTensors rooted in an object saved onto the graph, then it has no references keeping it alive as soon as the original tensor dies.
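A rough sketch of that idea (hypothetical names, not the actual prototype):

class RecomputableTensor:
    """Holds no data itself; knows how to rebuild its tensor on demand."""

    def __init__(self, func=None, parents=(), materialized=None):
        self.func = func                    # op that produced the tensor
        self.parents = parents              # RecomputableTensor inputs
        self.materialized = materialized    # real tensor, if already available

    def materialize(self):
        # Recurse until we hit already-materialized nodes. A selective
        # checkpoint policy would pre-materialize the outputs of the ops
        # it chose to save, so this recursion stops early.
        if self.materialized is None:
            args = [p.materialize() for p in self.parents]
            self.materialized = self.func(*args)
        return self.materialized

With saved tensor hooks, pack would wrap the saved tensor into a RecomputableTensor (dropping the data), and unpack would call materialize() when backward needs it.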


Oh right, I forgot about __torch_dispatch__. (I lack a bit of experience with what I can hook/dispatch.)

Yeah, via this mechanism I could do exactly what I need. That seems great!

Note, in RETURNN, for the TensorFlow backend, I have implemented this API for gradient checkpointing, where you can do something like this:

with gradient_checkpoint_scope():
    x = a + b
y = x * c

All tensors in gradient_checkpoint_scope get specially marked. This is using TF graph mode. I just need to record the first node/op id and the last, and then I can iterate through all ops from the scope like this:

for op_id in range(first_op_id, last_op_id):
    op = graph._nodes_by_id[op_id]
    ...

And then I overwrite the grad func of all consumers:

for op_out in op.outputs:
    assert isinstance(op_out, tf.Tensor)
    for op_ in op_out.consumers():
        _set_wrapped_grad_func(op_)

This is before I calculate the gradients (tf.gradients, similar to Torch’s tensor.backward()).
In my wrapped gradient function, instead of passing any of the original tensors from that gradient_checkpoint_scope to the original gradient function, I copy the whole computation graph under this scope and pass the copied tensors. I don’t really need to care about adding unnecessary computations to the computation graph, as TensorFlow lazily calculates only whatever is needed in the end.

So, I wonder if I can implement the same logic in PyTorch, i.e. the same gradient_checkpoint_scope API. Very similar to your example, I could use __torch_dispatch__ to record the computation graph and also record all tensors which are created there. If I used saved_tensors_hooks on the whole model, I could check whether a tensor comes from such a checkpoint scope, and if so, recompute it.

If I don’t want to have saved_tensors_hooks on the whole model, I can solve this too: I can call saved_tensors_hooks.__enter__ when I enter the checkpoint scope, but not immediately call saved_tensors_hooks.__exit__ afterwards. Instead, I can hook into Tensor.__del__ (as described above) for all tensors created in that scope, and once they all get deleted, then I can do the saved_tensors_hooks.__exit__ (and also exit the __torch_dispatch__). So this stays very local then.
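A hedged sketch of how the recording part could look (all names hypothetical; a real implementation would have to key on the underlying storage, recurse through dropped parents, and use weak references so that recording does not keep tensors alive):

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten

class GradientCheckpointScope(TorchDispatchMode):
    def __init__(self):
        self.recipes = {}  # tensor -> (func, args, kwargs); see caveats above

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        out = func(*args, **(kwargs or {}))
        # Remember how each output tensor of the scope was computed.
        for t in tree_flatten(out)[0]:
            if isinstance(t, torch.Tensor):
                self.recipes[t] = (func, args, kwargs)
        return out

scope = GradientCheckpointScope()

def pack(t):
    # Drop the data for tensors created inside the scope; keep others as-is.
    if t in scope.recipes:
        return ("recompute", scope.recipes[t])
    return ("saved", t)

def unpack(packed):
    kind, payload = packed
    if kind == "saved":
        return payload
    func, args, kwargs = payload
    with torch.no_grad():  # re-run the recorded op to rebuild the tensor
        return func(*args, **(kwargs or {}))

a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
c = torch.randn(4, requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    with scope:
        x = a + b  # recorded; recomputed in backward instead of stored
    y = x * c
y.sum().backward()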

It gets a bit trickier to handle the RNG state correctly, and maybe also to handle AMP. I guess I need to make sure to store and then later fork and reset the state correctly, and also to replay the recorded computations in exactly the same order to make sure it’s all deterministic. But this is similar to the current Torch checkpoint logic.
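The RNG handling could look roughly like this (a sketch along the lines of what torch.utils.checkpoint itself does internally):

import torch

# During the original forward: snapshot the RNG state before running the ops.
cpu_rng_state = torch.get_rng_state()
# (plus torch.cuda.get_rng_state(...) for each device involved)

# Later, before recomputing: fork so the replay doesn't disturb global state,
# restore the snapshot, then replay the recorded ops in the original order.
with torch.random.fork_rng():
    torch.set_rng_state(cpu_rng_state)
    # ... replay recorded computations here ...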

So, this sounds like a plan…

Yes, torch dispatch is a very useful tool. Let me know how things go if you decide to work on it more.

RNG is indeed tricky! There are quite a few things that don’t work well with the prototype, but unfortunately I don’t have much time to work on it these days.