How do I checkpoint the last layer using `torch.utils.checkpoint`?

I will use a very simple example to explain my problem. Say I have a 6-layer feedforward model represented as the array [1, 2, 3, 4, 5, 6], and I want to divide it into 3 segments, [1, 2], [3, 4] and [5, 6], for gradient checkpointing. According to the comment at line 516 of `torch.utils.checkpoint.checkpoint_sequential` ("the last chunk has to be non-volatile"), PyTorch only calls `torch.utils.checkpoint` on the first two segments, [1, 2] and [3, 4], and runs the remaining layers 5 and 6 normally, saving their activations for backward. With this design, I don't know how to save the output of the last layer (layer 6) for the backward phase.
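A minimal sketch of the splitting described above, assuming a hypothetical 6-layer stand-in built from `nn.Linear(8, 8)` (the layer sizes, `run_segment`, `manual_checkpoint_sequential` and `grads_for` are illustrative, not from my real model). It mirrors the loop in `checkpoint_sequential`: the first two segments go through `checkpoint`, and the last one runs normally. It then compares the result with `checkpoint_sequential` itself and with a plain, uncheckpointed forward/backward. `use_reentrant=False` is passed explicitly because recent PyTorch versions ask for it.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential

torch.manual_seed(0)

# Hypothetical 6-layer model: model[0..5] play the role of layers 1..6.
model = nn.Sequential(*[nn.Linear(8, 8) for _ in range(6)])


def run_segment(start, end):
    """Return a function that applies model[start..end] (inclusive) in order."""
    def forward(x):
        for i in range(start, end + 1):
            x = model[i](x)
        return x
    return forward


def manual_checkpoint_sequential(x, segments=3):
    """Hand-written equivalent of the splitting done by checkpoint_sequential."""
    segment_size = len(model) // segments
    end = -1
    # Checkpoint every segment except the last: layers [1, 2] and [3, 4].
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        x = checkpoint(run_segment(start, end), x, use_reentrant=False)
    # The last segment, layers [5, 6], runs normally, so autograd stores its activations.
    return run_segment(end + 1, len(model) - 1)(x)


def grads_for(forward_fn, x):
    """Run forward + backward on sum(output); return (output, parameter grads)."""
    model.zero_grad()
    out = forward_fn(x)
    out.sum().backward()
    return out.detach(), [p.grad.clone() for p in model.parameters()]


x = torch.randn(4, 8)
out_plain, g_plain = grads_for(model, x)
out_manual, g_manual = grads_for(manual_checkpoint_sequential, x)
out_lib, g_lib = grads_for(
    lambda inp: checkpoint_sequential(model, 3, inp, use_reentrant=False), x
)
```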

Suppose I also call `torch.utils.checkpoint` on segment [5, 6]. Is it possible to write some extra code to make sure the output of layer 6 is saved for backward? Or do I not need to worry about this, because PyTorch always saves the output of the last layer to compute the first gradient (the partial derivative of the last layer's output with respect to itself)?
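A minimal sketch for testing this, assuming the same hypothetical `nn.Linear` stand-in and an MSE loss chosen only for illustration (`fully_checkpointed`, `train_step` and `SEGMENTS` are made-up names). All three segments, including [5, 6], are checkpointed. My understanding, which the script tries to check, is that the gradient entering layer 6 (dLoss/dOut) is produced by the loss's own backward node, which sits outside the checkpoint and saves whatever it needs itself (`mse_loss` keeps `out`); the checkpointed segment then recomputes layers 5 and 6 when that gradient arrives. A hook on `out` records that gradient, and the parameter gradients are compared against an uncheckpointed run.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Hypothetical 6-layer model again; segments are [1, 2], [3, 4], [5, 6].
model = nn.Sequential(*[nn.Linear(8, 8) for _ in range(6)])
SEGMENTS = [(0, 1), (2, 3), (4, 5)]  # inclusive index ranges into model


def run_segment(start, end):
    """Return a function that applies model[start..end] (inclusive) in order."""
    def forward(x):
        for i in range(start, end + 1):
            x = model[i](x)
        return x
    return forward


def fully_checkpointed(x):
    # Unlike checkpoint_sequential, the last segment [5, 6] is checkpointed too.
    for start, end in SEGMENTS:
        x = checkpoint(run_segment(start, end), x, use_reentrant=False)
    return x


def train_step(forward_fn, x, target):
    model.zero_grad()
    out = forward_fn(x)
    seen = []
    # Record dLoss/dOut as it flows into the output of layer 6.
    out.register_hook(lambda g: seen.append(g.clone()))
    # mse_loss runs outside any checkpoint, so autograd saves `out` for its backward.
    loss = F.mse_loss(out, target)
    loss.backward()
    grads = [p.grad.clone() for p in model.parameters()]
    return out.detach(), seen[0], grads


x, target = torch.randn(4, 8), torch.randn(4, 8)
out_plain, dout_plain, g_plain = train_step(model, x, target)
out_ckpt, dout_ckpt, g_ckpt = train_step(fully_checkpointed, x, target)

# For mean-reduced MSE, dLoss/dOut = 2 * (out - target) / out.numel().
expected_dout = 2 * (out_ckpt - target) / out_ckpt.numel()
```

If this holds, checkpointing the last segment seems to cost speed rather than correctness: layers 5 and 6 would be recomputed right at the start of backward, when their activations are needed immediately, which may be why `checkpoint_sequential` leaves the last chunk unchecked.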