Implementing multiple recomputations on top of `torch.utils.checkpoint`


My master's thesis is on reducing the memory usage of neural networks. One technique I am looking at is memory checkpointing. Given a memory budget and per-operator compute and memory costs, I can solve for the optimal checkpointing policy (including multiple recomputations).

I am attempting to implement memory checkpointing as done in torch.utils.checkpoint, but allowing for multiple recomputations. However, there are a couple of things in the implementation that I'm not quite sure I understand. Apologies if anything is obvious; I have been using PyTorch for less than two days.

  1. In CheckpointFunction's backward, why detach the input before recomputing the forward?

    Would you otherwise get something like the forward's gradient being accumulated twice into the input's gradient? Also, doesn't detaching duplicate the input, resulting in a higher memory cost than necessary?

  2. Why does CheckpointFunction's backward return (None, None) + grads, where grads are the gradients with respect to the inputs? I'm really confused here.

  3. CheckpointFunction's forward produces an output that does not require grad, so surely the next layer's backward will not be set up to propagate the gradient back into it?

    Say I have a model whose forward performs h(checkpoint(g, f(x))), which I’ll equivalently write as f -> checkpoint(g) -> h. Say x has requires_grad=True. Then, my understanding of autograd is that something along the lines of the following will occur:

    1. f's forward will see that its input requires grad and thus produce an output that requires grad too (and whose gradient function is f's backward function). It will save the relevant tensors for the backward and set the backward to propagate the gradient back to its input’s gradient function.

    2. The output will be fed into CheckpointFunction's forward. It will behave similarly to what is described above, except for this (from the source code):

      with torch.no_grad():
        outputs = run_function(*args)
      return outputs

      This code means that the output will have requires_grad=False and that no backward graph will have been created.

    3. Thus, when this output is fed into h's forward (even if h contains parameters that require grad, so that h's output requires grad too), h will not set up the backward graph to propagate the gradient to the output's gradient function, which is CheckpointFunction's backward. This would mean CheckpointFunction's backward does not get invoked on a call to h.backward(), and so neither will g.backward() or f.backward().

  4. Implementing multiple recomputations by checkpoint()-ing a model whose child model is also checkpoint()-ed will result in the child saving its input in the first forward pass, not in the second (the first recomputation pass) as intended.

    I was trying to implement multiple recomputations by building trivially on the existing checkpoint function, for example by creating a module that performs 1 -> drop(2) -> 3, where module 2 itself performs 2a -> drop(2b), and drop is a higher-order module whose forward simply performs checkpoint(child_model, x). Thus, running 2 should drop 2a and 2b in the forward pass; in the backward pass, it should recompute 2a, recompute 2b and drop it, then recompute 2b a second time, this time actually saving it. Obviously, that's not very smart, but it's a simple example of multiple recomputations: all of 2 is dropped in the first forward pass, and when 2 is recomputed, 2a is checkpointed and 2b is dropped again.

    That is the intent, but I believe the following chain of events will occur in practice:

    1. 1's forward is performed, and its output is passed to drop(2)'s forward function.

    2. drop(2)'s forward invokes CheckpointFunction's forward, which saves (checkpoints) the input and then invokes 2's forward without tracking gradients. Again, the intent is that the outputs of both 2a and 2b will be dropped at this stage.

    3. 2a performs its forward and the output is passed to drop(2b)'s forward.

    4. This invokes CheckpointFunction's forward, which saves the input to 2b and thus checkpoints it. As mentioned in step 2, I do not want this to happen!

      On the other hand, maybe it will be freed because of how everything is set up and autograd’s reference counting?

      Maybe, since drop(2b) saves its input but propagates forward with no grad, and 2a itself was run with no grad, you get an unreachable reference cycle between them that gets garbage collected?

      I lack the autograd knowledge to tell (see: less than two days of PyTorch). I also do not know how to profile this to check whether the tensor is dropped or not.

    If my analysis is right, I will have to implement this drop operator from scratch so that it avoids this behaviour, correct?
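Point 3 can be checked directly with a toy custom Function. In this minimal sketch (`ScaleNoGrad` is a hypothetical name, not part of PyTorch), the forward computes its result under `torch.no_grad()`, like the quoted checkpoint code, yet the tensor returned by `.apply` still requires grad and has a `grad_fn`: autograd attaches the Function's own backward node to its outputs whenever an input requires grad, regardless of what happens inside the forward. The backward also returns one entry per forward argument, with `None` for the non-tensor `factor`.

```python
import torch


class ScaleNoGrad(torch.autograd.Function):
    """Hypothetical toy Function: multiplies x by a Python float."""

    @staticmethod
    def forward(ctx, factor, x):
        ctx.factor = factor
        # Mirrors the checkpoint code. A custom Function's forward already runs
        # without recording a graph, so this is redundant but harmless.
        with torch.no_grad():
            out = x * factor
        return out

    @staticmethod
    def backward(ctx, grad_out):
        # One return value per forward argument (ctx excluded):
        # None for `factor`, a tensor for `x`.
        return None, grad_out * ctx.factor


x = torch.ones(3, requires_grad=True)
y = ScaleNoGrad.apply(2.0, x)
print(y.requires_grad, y.grad_fn)

# A downstream op (playing the role of h) still routes its gradient
# into ScaleNoGrad.backward.
(y * 5).sum().backward()
print(x.grad)
```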

Thank you for making it this far, and sorry if the explanations above are not great; they would be easier with diagrams. Any help on this would be greatly appreciated.


Note that it's been a while since I last looked at this. Also, I think the source of torch.utils.checkpoint.CheckpointFunction might be slightly different now, though hopefully not in a way that affects this question.

  1. If we didn't detach the input, then when we recompute the output and run the backward, backpropagation would continue past the input, because the input is still attached to the gradient graph of the earlier computations that created it. However, the function we are implementing (CheckpointFunction's backward) is supposed to be just one node in that gradient graph and return only its own gradients, so we do not want that side effect.

    As for whether detaching duplicates the underlying buffers, I don't remember what I found, but it should be simple to verify by experiment: each tensor object holds a pointer to its underlying buffer, so make a detached tensor and check whether the two pointers are equal.

  2. This was actually really simple but took far too long for me to figure out. Autograd's API assumes that the backward function returns one gradient for each argument given to the forward function, excluding the leading context argument (ctx). If we look at the corresponding forward of CheckpointFunction, we see that its first two arguments after ctx (the function to run and the preserve_rng_state flag) are 'pseudo inputs' that configure the checkpointing rather than actual tensors. We still need to account for them when returning the gradients from the backward function, hence (None, None) + grads.

  3. I’m afraid I can’t recall how I figured this out. Sorry.

  4. I think my analysis here was correct; the tensors that should be checkpointed only on the first recomputation were being checkpointed immediately.

    To get around this, I added a recomputation_depth parameter to the forward function of CheckpointFunction and only checkpointed the input if the depth was 0. The Drop module carries its own recomputation_depth parameter and, in its forward, decrements it before invoking the checkpoint function.

    Note that, as explained in point 2, we must now return (None, None, None) + grads in the backward, because of the additional argument given to the forward.

    This successfully allows you to encode multiple recomputations using these Drop modules, for example by doing:

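        model = nn.Sequential(
            layer1,
            Drop(nn.Sequential(layer2,
                               Drop(nn.Sequential(layer3, layer4),
                                    2),   # <----- recursion depth of inner drop
                               layer5),
                 1))                      # <----- recursion depth of outer drop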

    where layer 1 is not recomputed, layers 2 and 5 are recomputed once, and layers 3 and 4 are recomputed twice.
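The experiment suggested in point 1 takes only a few lines. In this sketch (the tensor sizes are arbitrary), detaching a tensor the way the checkpoint backward does yields a new tensor object that shares the original's buffer, so no extra activation memory is allocated; only a small tensor header is created.

```python
import torch

x = torch.randn(1024, 1024, requires_grad=True)
h = x * 2  # non-leaf tensor that is part of the autograd graph

# What CheckpointFunction's backward does to a saved input before recomputing.
d = h.detach().requires_grad_(h.requires_grad)

print(d is h)                        # False: a new tensor object...
print(d.data_ptr() == h.data_ptr())  # True: ...backed by the same buffer
print(d.grad_fn is None)             # True: cut off from the graph that created h
```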
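The design in point 4 can be sketched as follows. This is a minimal reconstruction under stated assumptions, not the original code: `DepthCheckpointFunction`, `Drop`, `Counted` and `calls` are hypothetical names, RNG-state preservation is omitted (so the backward returns `(None, None) + grads` rather than `(None, None, None) + grads`), and "decrements it before invoking" is interpreted as a per-module counter that starts at `recomputation_depth`, is decremented on every forward call, saves the input only when it reaches 0, and then resets for the next iteration. Counting layer calls over one forward and backward pass reproduces the recomputation counts described above, and the gradients match a plain, non-checkpointed model.

```python
import torch
import torch.nn as nn


class DepthCheckpointFunction(torch.autograd.Function):
    """Hypothetical reentrant checkpoint with a recomputation depth (no RNG handling)."""

    @staticmethod
    def forward(ctx, run_function, recomputation_depth, *args):
        ctx.run_function = run_function
        if recomputation_depth == 0:
            ctx.save_for_backward(*args)  # only this call keeps its inputs alive
        return run_function(*args)  # runs without recording a graph

    @staticmethod
    def backward(ctx, *grads):
        # Detach so backpropagating the recomputation stops at these inputs.
        inputs = [t.detach().requires_grad_(t.requires_grad) for t in ctx.saved_tensors]
        with torch.enable_grad():
            outputs = ctx.run_function(*inputs)
        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        torch.autograd.backward(outputs, grads)
        # One entry per forward argument: None for run_function and the depth.
        return (None, None) + tuple(t.grad for t in inputs)


class Drop(nn.Module):
    """Hypothetical wrapper: `module` is recomputed `recomputation_depth` (>= 1) times."""

    def __init__(self, module, recomputation_depth):
        super().__init__()
        self.module = module
        self.recomputation_depth = recomputation_depth
        self._remaining = recomputation_depth

    def forward(self, *args):
        self._remaining -= 1  # decrement before invoking the checkpoint function
        depth = self._remaining
        if depth == 0:
            self._remaining = self.recomputation_depth  # reset for the next iteration
        return DepthCheckpointFunction.apply(self.module, depth, *args)


calls = {}


class Counted(nn.Module):
    """Linear + tanh layer that records how often its forward runs."""

    def __init__(self, name, dim):
        super().__init__()
        self.name = name
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        calls[self.name] = calls.get(self.name, 0) + 1
        return torch.tanh(self.linear(x))


def gradients(net, x):
    """Return the input gradient of net(x).sum() followed by all parameter gradients."""
    net.zero_grad()
    x = x.detach().clone().requires_grad_(True)
    net(x).sum().backward()
    return [x.grad] + [p.grad.clone() for p in net.parameters()]


torch.manual_seed(0)
layer1, layer2, layer3, layer4, layer5 = (Counted(f"layer{i}", 4) for i in range(1, 6))
model = nn.Sequential(
    layer1,
    Drop(nn.Sequential(layer2,
                       Drop(nn.Sequential(layer3, layer4), 2),  # inner drop
                       layer5),
         1))                                                    # outer drop
plain = nn.Sequential(layer1, layer2, layer3, layer4, layer5)

x = torch.randn(3, 4)
calls.clear()
checkpointed_grads = gradients(model, x)
checkpointed_calls = dict(calls)
print(checkpointed_calls)
# {'layer1': 1, 'layer2': 2, 'layer3': 3, 'layer4': 3, 'layer5': 2}

plain_grads = gradients(plain, x)
print(all(torch.allclose(a, b) for a, b in zip(checkpointed_grads, plain_grads)))
```

Because the counters advance on every forward call, this interpretation assumes every forward is part of a training step; running the model under `torch.no_grad()` for evaluation would desynchronise them.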

Hope this helps.

Is there a way to swap data in and out (between GPU and CPU memory) in PyTorch? As far as I know, recomputation costs some compute efficiency, so is it possible to use swapping instead in PyTorch?

I haven't looked much into PyTorch implementations of swapping, but I do recall there may have been some work on implementing something similar to TensorFlow Large Model Support (TFLMS), a swapping solution for TensorFlow.
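Swapping saved activations can be sketched with `torch.autograd.graph.saved_tensors_hooks`, which I believe is available from PyTorch 1.10 onwards (the same releases also provide a ready-made `torch.autograd.graph.save_on_cpu` context manager built on this mechanism). In this minimal sketch, the pack hook moves every tensor autograd saves for the backward pass into CPU memory, and the unpack hook moves it back to its original device when the backward needs it. The hook names and the `swapped` counter are illustrative; on a CPU-only machine the moves are no-ops, so the sketch only checks that the gradients are unchanged.

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
swapped = {"packed": 0, "unpacked": 0}


def pack_to_cpu(tensor):
    # Called when autograd saves a tensor for backward: swap it out to host memory.
    swapped["packed"] += 1
    return tensor.device, tensor.to("cpu")


def unpack_from_cpu(packed):
    # Called when the backward pass needs the tensor: swap it back in.
    swapped["unpacked"] += 1
    original_device, cpu_tensor = packed
    return cpu_tensor.to(original_device)


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 8), nn.Tanh(), nn.Linear(8, 1)).to(device)
x = torch.randn(4, 8, device=device)

# Reference gradients without swapping.
model(x).sum().backward()
reference = [p.grad.clone() for p in model.parameters()]
model.zero_grad()

with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    loss = model(x).sum()
loss.backward()  # the unpack hooks run here, outside the context manager

print(swapped)
print(all(torch.allclose(p.grad, r) for p, r in zip(model.parameters(), reference)))
```

This naive version swaps every saved tensor, parameters included, and copies synchronously; a practical scheme like TFLMS would choose which tensors to swap and overlap the transfers with computation.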

A quick Google search gave me this thread: Thoughts on use of CPU ram as a swap for GPU

Original TFLMS Paper: