Implementing multiple recomputations on top of `torch.utils.checkpoint`


My master's thesis is on making neural nets use less memory. One technique I am looking at is memory checkpointing. I can solve for the optimal policy (including multiple recomputations), given the memory budget and per-operator compute/memory costs.

I am attempting to implement memory checkpointing as done in `torch.utils.checkpoint`, except allowing for multiple recomputations. However, there are a couple of things in the implementation that I'm not quite sure I understand. Apologies if anything is obvious; I have been using PyTorch for <2 days.

  1. In CheckpointFunction's backward, why detach the input before recomputing the forward?

    Without the detach, would you get something like the gradient of the forward being accumulated twice into the input's gradient? Also, does this detaching not duplicate the input, resulting in a higher-than-necessary memory cost?
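Regarding the second half of this question, a quick experiment (my own sketch, not from the checkpoint source) suggests detach() does not duplicate the data: it creates a new tensor object over the same underlying storage, so the extra cost should only be the tensor metadata.

```python
import torch

# Sketch: check whether detach() copies the input's data.
x = torch.randn(1000, requires_grad=True)
d = x.detach()

# detach() returns a new tensor that shares the SAME storage as x,
# so the data buffer is not duplicated.
shares_storage = d.data_ptr() == x.data_ptr()

# The detached tensor is a leaf that is cut out of the autograd graph.
detached_leaf = (not d.requires_grad) and (d.grad_fn is None)
```

So my remaining confusion is only about the first half: why the detach is needed for correctness.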

  2. Why does CheckpointFunction's backward return `(None, None) + grads`, where grads are the gradients w.r.t. the inputs? I really am confused here.
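From playing with a toy autograd.Function (Scale below is my own example, not from the checkpoint source), my current guess is that backward must return exactly one value per argument of forward, with None for non-tensor arguments; CheckpointFunction.forward takes run_function and preserve_rng_state before the actual tensor inputs, which would explain the two leading Nones. Is that all there is to it?

```python
import torch

class Scale(torch.autograd.Function):
    # forward takes a non-tensor argument (factor) before the tensor x,
    # loosely mirroring how CheckpointFunction.forward takes
    # run_function and preserve_rng_state before the real inputs.
    @staticmethod
    def forward(ctx, factor, x):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward argument (excluding ctx):
        # None for the non-tensor factor, a real gradient for x.
        return None, grad_output * ctx.factor

x = torch.randn(3, requires_grad=True)
Scale.apply(2.0, x).sum().backward()
grad_ok = torch.allclose(x.grad, torch.full_like(x, 2.0))
```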

  3. CheckpointFunction's forward produces an output that does not require grad, so surely the next layer's backward will not be set up to propagate to its backward?

    Say I have a model whose forward performs h(checkpoint(g, f(x))), which I’ll equivalently write as f -> checkpoint(g) -> h. Say x has requires_grad=True. Then, my understanding of autograd is that something along the lines of the following will occur:

    1. f's forward will see that its input requires grad and thus produce an output that requires grad too (and whose gradient function is f's backward function). It will save the relevant tensors for the backward and set the backward to propagate the gradient back to its input’s gradient function.

    2. The output will be fed into CheckpointFunction's forward. It will behave similarly to the above, except for this (from the source code):

      with torch.no_grad():
        outputs = run_function(*args)
      return outputs

      This code means that outputs will have requires_grad=False, and that no backward gradient graph will have been created.

    3. Thus, when this output is fed into h's forward - even if h contains parameters that require grad, meaning h's output will require grad too - it will not set up the backward graph to propagate the gradient to the output's gradient function, which is CheckpointFunction's backward. This would mean CheckpointFunction's backward will not get invoked on a call to h.backward(), and so neither will g's backward or f's backward.
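When I probe this with a minimal experiment (assuming a PyTorch recent enough to accept the use_reentrant flag), the checkpointed output does seem to require grad and the gradient does reach x, so I am clearly missing something about how Function.apply interacts with the internal no_grad:

```python
import torch
from torch.utils.checkpoint import checkpoint

g = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)  # stands in for f's output
y = checkpoint(g, x, use_reentrant=True)

out_requires_grad = y.requires_grad  # does the checkpointed output require grad?
y.sum().backward()
input_got_grad = x.grad is not None  # did the gradient reach x?
```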

  4. Implementing multiple recomputations by checkpoint()-ing a model whose child model is itself checkpoint()-ed will result in the child saving its input during the first forward pass, not during the second (the first recomputation pass) as intended.

    I was trying to implement multiple recomputations by building trivially on the existing checkpoint function, for example by creating a module that performs 1 -> drop(2) -> 3, where module 2 itself performs 2a -> drop(2b), and drop is a higher-order module whose forward simply performs checkpoint(child_model, x). Thus, running 2 should drop 2a and 2b in the forward pass; in the backward pass, it should recompute 2a, recompute 2b and drop it, then recompute 2b a second time, this time actually saving it. Obviously that's not a smart policy, but it's a simple example of multiple recomputations: all of 2 is dropped in the first forward pass, and when we recompute it, 2a is checkpointed and 2b is dropped again.
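Concretely, the drop wrapper I have in mind is something like the following sketch (Drop is my own name; I'm assuming the reentrant checkpoint implementation):

```python
import torch
from torch.utils.checkpoint import checkpoint

class Drop(torch.nn.Module):
    """Hypothetical higher-order module whose forward just checkpoints its child."""
    def __init__(self, child):
        super().__init__()
        self.child = child

    def forward(self, x):
        # drop the child's activations in the forward pass,
        # recompute them during the backward pass
        return checkpoint(self.child, x, use_reentrant=True)

# the 1 -> drop(2) -> 3 pipeline, where 2 = 2a -> drop(2b)
m1, m2a, m2b, m3 = (torch.nn.Linear(4, 4) for _ in range(4))
model = torch.nn.Sequential(m1, Drop(torch.nn.Sequential(m2a, Drop(m2b))), m3)

x = torch.randn(2, 4, requires_grad=True)
out = model(x)
shape_ok = out.shape == (2, 4)
```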

    Or that is the intent, but I believe the following chain of events will occur in practice:

    1. 1's forward is performed, whose output is propagated to drop(2)'s forward function.

    2. drop(2)'s forward invokes CheckpointFunction's forward, which saves (checkpoints) the input, and then invokes 2's forward without tracking gradient. Again, the intent is that both the outputs of 2a and 2b will be dropped at this stage.

    3. 2a performs its forward and the output is passed to drop(2b)'s forward.

    4. This invokes CheckpointFunction's forward which will save the input to 2b thus checkpointing it, which, as mentioned in step 2, I do not want to happen!

      On the other hand, maybe it will be freed because of how everything is set up and autograd’s reference counting?

      Maybe, since drop(2b) saves its input but propagates forward with no grad, while 2a itself was also run with no grad, you get an unreachable reference cycle between them that will get garbage collected?

    I really lack the autograd understanding to know (see the <2 days of PyTorch). I also do not know how best to profile this to observe whether the tensor is actually dropped.
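The best probe I have come up with (a sketch; I am not sure it is reliable) is to hold a weakref to the intermediate activation and check whether it is still alive after the forward pass:

```python
import gc
import weakref
import torch
from torch.utils.checkpoint import checkpoint

lin_a = torch.nn.Linear(4, 4)  # plays the role of 2a
lin_b = torch.nn.Linear(4, 4)  # plays the role of 2b
probe = {}

def block2(t):
    mid = lin_a(t)                    # output of 2a
    probe["mid"] = weakref.ref(mid)   # watch whether this tensor stays alive
    return checkpoint(lin_b, mid, use_reentrant=True)  # drop(2b)

x = torch.randn(2, 4, requires_grad=True)
y = checkpoint(block2, x, use_reentrant=True)  # drop(2)
gc.collect()
mid_alive = probe["mid"]() is not None  # was 2a's output kept around?
```

If mid_alive is False after the forward pass, the inner save was discarded after all; if it is True, the inner checkpoint really does pin 2a's output.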

    If my analysis is right, I will have to implement this drop operator from scratch so that it avoids this behaviour, correct?

Thank you for making it this far, and sorry if the above explanations are not great; it would be easier with diagrams. Any help on this would be greatly appreciated.
