torch.utils.checkpoint.checkpoint_sequential does not checkpoint the last module

I am trying to use torch.utils.checkpoint.checkpoint_sequential, but it seems it does not checkpoint the last module (the last element in functions). This is made clear in the docs [1] ("all segments except the last will not store the intermediate activations") and in the source code [2] ("the last chunk has to be non-volatile"). However, I don't understand why this would be necessary.

[1] torch.utils.checkpoint — PyTorch 2.1 documentation
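To make the quoted behaviour concrete, here is a pure-Python model of the segmentation loop. It assumes the loop follows the structure of the PyTorch 2.1 source: a segment size of `len(functions) // segments`, every segment except the last wrapped in `checkpoint`, and the last segment run directly while absorbing any remainder. `plan_segments` is a hypothetical helper that only returns index ranges; it does not call PyTorch.

```python
def plan_segments(num_functions, segments):
    """Hypothetical model of checkpoint_sequential's segmentation loop.

    Returns (checkpointed, plain): a list of inclusive (start, end) index
    ranges that would go through checkpoint(...), and the single inclusive
    range that would run without checkpointing.
    """
    segment_size = num_functions // segments
    checkpointed = []
    end = -1
    # Every segment except the last is wrapped in checkpoint(...).
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        checkpointed.append((start, end))
    # The last segment runs directly and absorbs any remainder.
    plain = (end + 1, num_functions - 1)
    return checkpointed, plain


print(plan_segments(8, 4))
print(plan_segments(10, 3))
```

Under this model, `segments=1` leaves the checkpointed list empty, i.e. nothing is checkpointed at all, which is another way to see that the last chunk is always run without checkpointing.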

I think it is assuming that you will not be doing any further computation before computing the loss and calling backward. This would mean that you'll need those activations right away anyway, so we wouldn't want to drop them only to have to recompute them immediately afterward.
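One way to observe this is to count how often each module's forward runs. This is a sketch assuming PyTorch 2.x with non-reentrant checkpointing (`use_reentrant=False`); `CountingLinear` is a hypothetical helper defined only for illustration. Modules in the checkpointed segment should run forward a second time during backward, while modules in the last segment should run only once.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential


class CountingLinear(nn.Linear):
    """Hypothetical nn.Linear that records how many times forward runs."""

    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        self.calls = 0

    def forward(self, x):
        self.calls += 1
        return super().forward(x)


torch.manual_seed(0)
model = nn.Sequential(*[CountingLinear(16, 16) for _ in range(4)])
x = torch.randn(2, 16)

# Two segments: modules 0-1 go through checkpoint, modules 2-3 run normally.
out = checkpoint_sequential(model, 2, x, use_reentrant=False)
out.sum().backward()

# Forward-call count per module after one forward and one backward pass.
calls = [m.calls for m in model]
print(calls)
```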

Btw, if you are interested, we would accept a PR clarifying the docs here.

This doesn’t make much sense to me. This last chunk that is not checkpointed is one (or more) module(s), which might have many internal activations that will be stored for backward.

Besides, "the last chunk has to be non-volatile" sounds more like a requirement than a design decision, yet I do not understand why this has to be the case.

As a user, when calling checkpoint_sequential, I expect all my modules to be checkpointed. It seems a lot more reasonable to me that if users do not want their last chunk to be checkpointed, they exclude it from the arguments of checkpoint_sequential.

Sure, there are many internal activations, but they are treated as a unit because we do a single checkpoint call per chunk. So at some point between (1) the original forward computation of the last chunk and (2) the beginning of the backward of that last chunk, you are going to have to materialize every single activation in that chunk whether you checkpoint it or not, so you might as well skip the extra recomputation.

Hmm that seems reasonable, but unfortunately it would be bc-breaking to change the semantics of checkpoint_sequential today.

This goes against the logic of torch.utils.checkpoint.checkpoint, which does not itself store its output tensor. So one would expect the sequential version to behave the same way:

  1. Checkpoint everything the user passes
  2. Do not store the outputs for backward, expecting a subsequent operation to do so before backward is called.

I understand that this would break backward compatibility. But I now understand that this is not a requirement but a design decision, and frankly it seems pretty inconsistent and not user-friendly.
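For users who do want every segment checkpointed, the behaviour proposed above can be sketched on top of `torch.utils.checkpoint.checkpoint`. `checkpoint_all_segments` is a hypothetical helper, not a PyTorch API; it assumes the same segment sizing as `checkpoint_sequential` (the last segment absorbs any remainder) but also wraps the last segment in `checkpoint`, so its activations are recomputed during backward as well.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def checkpoint_all_segments(modules, segments, x):
    """Hypothetical variant of checkpoint_sequential that also checkpoints
    the last segment."""
    modules = list(modules)
    segment_size = len(modules) // segments
    start = 0
    for i in range(segments):
        # The last segment absorbs any remainder, as in checkpoint_sequential.
        end = len(modules) if i == segments - 1 else start + segment_size
        # A fresh nn.Sequential per segment (rather than a closure over the
        # loop variables) ensures recomputation runs the same modules.
        segment = nn.Sequential(*modules[start:end])
        x = checkpoint(segment, x, use_reentrant=False)
        start = end
    return x


torch.manual_seed(0)
model = nn.Sequential(*[nn.Linear(8, 8) for _ in range(5)])
x = torch.randn(3, 8)
out = checkpoint_all_segments(model, 2, x)
out.sum().backward()
```

The trade-off is the one discussed in this thread: the last segment's activations no longer stay alive between forward and backward, at the cost of recomputing that segment during backward.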