My master's thesis is on making neural nets use less memory. One technique I am looking at is memory checkpointing. I can solve for the optimal policy (including multiple recomputations), given the memory budget and per-operator compute/memory costs.
I am attempting to implement memory checkpointing as done in `torch.utils.checkpoint`, except allowing for multiple recomputations. However, there are a couple of things in the implementation that I'm not quite sure I understand. Apologies if anything is obvious; I have been using PyTorch for <2 days.
In `CheckpointFunction`'s backward, why detach the input before recomputing the forward? Would you otherwise get something like the gradient of the forward being accumulated twice into the input's gradient? Also, does this detaching not duplicate the input, resulting in a higher-than-necessary memory cost?
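To make the memory half of that question concrete, here is a small check I put together (a sketch only; `g` is a made-up function, and `use_reentrant=True` selects the `CheckpointFunction`-based implementation on recent PyTorch versions):

```python
import torch
from torch.utils.checkpoint import checkpoint

x = torch.randn(4, requires_grad=True)

# detach() returns a new Tensor *object* but shares the same storage,
# so it should not duplicate the input's data in memory.
assert x.detach().data_ptr() == x.data_ptr()


def g(t):
    # made-up checkpointed function
    return (t * t).sum()


# Checkpointing should also leave the gradient itself unchanged,
# i.e. nothing gets accumulated twice into the input's grad.
x2 = x.detach().clone().requires_grad_()
g(x).backward()
checkpoint(g, x2, use_reentrant=True).backward()
assert torch.allclose(x.grad, x2.grad)
```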
Why does `CheckpointFunction`'s backward return `(None, None) + grads`, where `grads` are the grads w.r.t. the inputs? I really am confused here.
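For context, my current understanding of the `autograd.Function` contract is that backward must return one value per forward argument, with `None` in the positions of non-tensor arguments; `Scale` below is a toy function I wrote to convince myself of that pattern:

```python
import torch


class Scale(torch.autograd.Function):
    # forward takes a non-tensor argument (factor), so backward must
    # return None in that position: one return value per forward input.
    @staticmethod
    def forward(ctx, factor, x):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # None for `factor`, then the grad w.r.t. x
        return None, grad_out * ctx.factor


x = torch.randn(3, requires_grad=True)
Scale.apply(3.0, x).sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 3.0))
```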
`CheckpointFunction`'s forward produces an output that does not require grad, so surely the next layer's backward will not be set up to propagate to its backward?
Say I have a model whose forward performs `h(checkpoint(g, f(x)))`, which I'll equivalently write as `f -> checkpoint(g) -> h`. Say `x.requires_grad=True`. Then, my understanding of autograd is that something along the lines of the following will occur:
1. `f`'s forward will see that its input requires grad and thus produce an output that requires grad too (and whose gradient function is `f`'s backward function). It will save the relevant tensors for the backward pass and set the backward to propagate the gradient back to its input's gradient function.
2. The output will be fed into `CheckpointFunction`'s forward. It will behave similarly to what is described above, except for this (from the source code):

   ```python
   with torch.no_grad():
       outputs = run_function(*args)
   ```

   This code means that `outputs` has `requires_grad=False`, and that no backward gradient graph will have been created for it.
3. Thus, when this `output` is fed into `h`'s forward, it will not set up the backward graph to propagate gradient to `output`'s gradient function (which is `CheckpointFunction`'s backward), even if `h` contains parameters that require grad, meaning `h`'s output requires grad too. This would mean `CheckpointFunction`'s backward will not get invoked on a call to `h.backward()`, and so neither will `g`'s.
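If it helps, this is how I tried to test that reasoning, with `f`, `g`, and `h` as placeholder `nn.Linear` layers (the assertions pass for me, which makes me suspect my understanding of step 3 is wrong somewhere):

```python
import torch
from torch.utils.checkpoint import checkpoint

# placeholder layers standing in for f, g, h
f = torch.nn.Linear(4, 4)
g = torch.nn.Linear(4, 4)
h = torch.nn.Linear(4, 4)

x = torch.randn(2, 4, requires_grad=True)

# f -> checkpoint(g) -> h
y = checkpoint(g, f(x), use_reentrant=True)
h(y).sum().backward()

# If CheckpointFunction's backward were never invoked, neither x nor
# f's parameters could have received a gradient.
assert x.grad is not None
assert f.weight.grad is not None
```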
Implementing multiple recomputations by `checkpoint()`-ing a model whose child model is also `checkpoint()`-ed will result in the child saving its input in the first forward pass, not in the second (the first recomputation pass) as intended.
I was trying to implement multiple recomputations by trivially building on the existing `checkpoint` function, for example by creating a module that performs `1 -> drop(2) -> 3`, where the module `2` performs `2a -> drop(2b)`, and where `drop` is like a higher-order model whose forward simply performs `checkpoint(child_model, x)`. Thus, running the model should drop the outputs of `2a` and `2b` in the forward pass; and in the backward pass, recompute `2b` and drop it, then recompute `2b` a second time, this time actually saving it. Obviously, that's not so smart, but it is a simple example of multiple recomputations: all of `2` is dropped in the first forward pass, and when we are recomputing it, `2a` is checkpointed and `2b` is dropped again.
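Concretely, the `drop` wrapper I have in mind is just a module that checkpoints its child; `Drop` and the `m1`/`m2a`/`m2b`/`m3` names below are my own placeholders:

```python
import torch
from torch.utils.checkpoint import checkpoint


class Drop(torch.nn.Module):
    """Higher-order module: checkpoint the child model in forward."""

    def __init__(self, child):
        super().__init__()
        self.child = child

    def forward(self, x):
        # use_reentrant=True selects the CheckpointFunction path
        return checkpoint(self.child, x, use_reentrant=True)


m1, m2a, m2b, m3 = (torch.nn.Linear(4, 4) for _ in range(4))
m2 = torch.nn.Sequential(m2a, Drop(m2b))       # 2 = 2a -> drop(2b)
model = torch.nn.Sequential(m1, Drop(m2), m3)  # 1 -> drop(2) -> 3

x = torch.randn(2, 4, requires_grad=True)
model(x).sum().backward()

# gradients should still match the un-checkpointed composition
x2 = x.detach().clone().requires_grad_()
torch.nn.Sequential(m1, m2a, m2b, m3)(x2).sum().backward()
assert torch.allclose(x.grad, x2.grad)
```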
Or that is the intent, but I believe the following chain of events will occur in practice:
1. `1`'s forward is performed, and its output is propagated to `drop(2)`'s forward function.
2. `drop(2)`'s forward invokes `CheckpointFunction`'s forward, which saves (checkpoints) the input and then invokes `2`'s forward without tracking gradients. Again, the intent is that the outputs of both `2a` and `2b` will be dropped at this stage.
3. `2a` performs its forward, and the output is passed to `CheckpointFunction`'s forward, which will save the input to `2b`, thus checkpointing it, which, as mentioned in step 2, I do not want to happen!
On the other hand, maybe it will be freed because of how everything is set up and autograd's reference counting? `drop(2b)` saves its input, but it propagates forward with no grad, and `2a` itself was run with no grad; do you get an unreachable reference cycle between them that will get garbage-collected?
I really lack the autograd understanding to know (see the <2 days of PyTorch). I also do not know how to profile this to observe whether it drops the tensor or not.
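The closest thing I have found for observing this is `torch.autograd.graph.saved_tensors_hooks` (PyTorch >= 1.10), which fires a hook for every tensor autograd packs away for the backward pass; a sketch of counting saved tensors with and without checkpointing:

```python
import torch
from torch.utils.checkpoint import checkpoint

model = torch.nn.Sequential(
    torch.nn.Linear(4, 4), torch.nn.ReLU(), torch.nn.Linear(4, 4)
)
x = torch.randn(2, 4, requires_grad=True)


def saved_shapes(run):
    """Record the shape of every tensor saved for backward during run()."""
    shapes = []

    def pack(t):
        shapes.append(tuple(t.shape))
        return t

    def unpack(t):
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        run()
    return shapes


plain = saved_shapes(lambda: model(x))
ckpt = saved_shapes(lambda: checkpoint(model, x, use_reentrant=True))

# The checkpointed forward should only save its input up front,
# while the plain forward saves every layer's intermediates.
assert len(ckpt) < len(plain)
```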
If my analysis is right, I will have to implement this `drop` operator from scratch such that it avoids this behaviour, correct?
Thank you for making it this far, and sorry if the above explanations are not great; it would be easier with diagrams. Any help on this would be greatly appreciated.