Hi,
My master's thesis is on making neural nets use less memory. One technique I am looking at is memory checkpointing: given a memory budget and per-operator compute/memory costs, I can solve for the optimal checkpointing policy (including multiple recomputations).
I am attempting to implement memory checkpointing as done in `torch.utils.checkpoint`, except allowing for multiple recomputations. However, there are a couple of things in the implementation that I'm not quite sure I understand. Apologies if anything is obvious; I have been using PyTorch for <2 days.
- In `CheckpointFunction`'s backward, why detach the input before recomputing the forward? Would you otherwise get something like the gradient of the forward being accumulated twice into the input's gradient? Also, does this detaching not duplicate the input, resulting in a higher-than-necessary memory cost? (I sketch my mental model of this backward in code after the list below.)
- Why does `CheckpointFunction`'s backward return `(None, None) + grads`, where `grads` are the grads w.r.t. the inputs? I really am confused here.
- `CheckpointFunction`'s forward produces an output that does not require grad, so surely the next layer's backward will not be set up to propagate to its backward? My reasoning for thinking this is as follows.
Say I have a model whose forward performs `h(checkpoint(g, f(x)))`, which I'll equivalently write as `f -> checkpoint(g) -> h`. Say `x` has `requires_grad=True`. Then, my understanding of autograd is that something along the lines of the following will occur:
- `f`'s forward will see that its input requires grad and thus produce an output that requires grad too (and whose gradient function is `f`'s backward function). It will save the relevant tensors for the backward and set the backward up to propagate the gradient back to its input's gradient function.
- The output will be fed into `CheckpointFunction`'s forward. It will behave similarly to what is described above, except for this (from the source code):

  ```python
  with torch.no_grad():
      outputs = run_function(*args)
  return outputs
  ```

  This code means that `outputs` will have `requires_grad=False`, and that no backward gradient graph will have been created.
- Thus, when this output is fed into `h`'s forward (even if `h` contains parameters that require grad, meaning `h`'s output will require grad too), autograd will not set up the backward graph to propagate the gradient to the output's backward function, which is `CheckpointFunction`'s backward. This would mean `CheckpointFunction`'s backward will not get invoked on a call to `h.backward()`, and so neither will `g.backward()` or `f.backward()`.
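To make the first two questions concrete, here is my rough mental model of what `CheckpointFunction`'s backward does, written as a standalone function. This is a paraphrase from memory, not the actual source, and `checkpoint_backward_sketch` plus its argument names are made up by me:

```python
import torch

def checkpoint_backward_sketch(run_function, saved_inputs, grad_outputs):
    # Detach the saved inputs so the recomputed graph is separate from whatever
    # graph the inputs originally came from.
    detached = tuple(t.detach().requires_grad_(t.requires_grad) for t in saved_inputs)
    with torch.enable_grad():
        # Recompute the forward, this time recording an autograd graph.
        outputs = run_function(*detached)
    if isinstance(outputs, torch.Tensor):
        outputs = (outputs,)
    # Backpropagate the incoming gradients through the recomputed graph only.
    torch.autograd.backward(outputs, grad_outputs)
    # Gradients w.r.t. the (detached copies of the) inputs; the real backward
    # returns (None, None) + these, which is part of what I am asking about.
    return tuple(t.grad for t in detached)
```

And here is a runnable toy version of the `f -> checkpoint(g) -> h` example above (the `nn.Linear` layers are just placeholders for `f`, `g` and `h`), which is the case whose gradient flow I am trying to understand:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

f = nn.Linear(4, 4)
g = nn.Linear(4, 4)
h = nn.Linear(4, 1)

x = torch.randn(2, 4, requires_grad=True)

# checkpoint should drop g's internal activations in the forward
# and recompute them during the backward.
y = h(checkpoint(g, f(x)))
y.sum().backward()

# My question: does the backward actually reach CheckpointFunction's backward,
# and hence f's backward, so that x.grad gets populated?
print(x.grad is not None)
```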
- Implementing multiple recomputations by checkpoint()-ing a model whose child model is itself checkpoint()-ed will result in the child saving its input in the first forward pass, not in the second (first recomputation) pass as intended.
I was trying to implement multiple recomputations by building trivially on the existing `checkpoint` function, for example by creating a module that performs `1 -> drop(2) -> 3`, where module `2` itself performs `2a -> drop(2b)`, and where `drop` is a higher-order module whose forward simply performs `checkpoint(child_model, x)`. Thus, running `2` should drop `2a` and `2b` in the forward pass; and in the backward pass, recompute `2a`, recompute `2b` and drop it, then recompute `2b` a second time, this time actually saving it. Obviously that is not a smart policy, but it is a simple example of multiple recomputations: all of `2` is dropped in the first forward pass, and when we are recomputing it, `2a` is checkpointed and `2b` is dropped again.
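A minimal sketch of what I mean, where `Drop` is my own name for the wrapper and `nn.Linear` layers stand in for the blocks `1`, `2a`, `2b` and `3`:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Drop(nn.Module):
    """Higher-order module whose forward simply performs checkpoint(child, x)."""
    def __init__(self, child):
        super().__init__()
        self.child = child

    def forward(self, x):
        return checkpoint(self.child, x)

# 2 = 2a -> drop(2b)
block_2 = nn.Sequential(nn.Linear(4, 4), Drop(nn.Linear(4, 4)))

# 1 -> drop(2) -> 3
model = nn.Sequential(nn.Linear(4, 4), Drop(block_2), nn.Linear(4, 1))

x = torch.randn(2, 4, requires_grad=True)
out = model(x)
# The question is what is still being held in memory at this point, before backward().
```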
Or that is the intent, but I believe the following chain of events will occur in practice:

1. `1`'s forward is performed, and its output is propagated to `drop(2)`'s forward function.
2. `drop(2)`'s forward invokes `CheckpointFunction`'s forward, which saves (checkpoints) the input and then invokes `2`'s forward without tracking gradients. Again, the intent is that the outputs of both `2a` and `2b` will be dropped at this stage.
3. `2a` performs its forward and its output is passed to `drop(2b)`'s forward.
4. This invokes `CheckpointFunction`'s forward, which will save the input to `2b`, thus checkpointing it, which, as mentioned in step 2, I do not want to happen!

On the other hand, maybe it will be freed anyway because of how everything is set up and autograd's reference counting?
Maybe, as `drop(2b)` saves its input, it propagates forward with no grad, and `2a` itself was run with no grad, so you get an unreachable reference cycle between them that will get garbage collected? I really lack the autograd understanding to know (see the <2 days of PyTorch). I also do not know how to profile this to observe whether the tensor is dropped or not.
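Would something like the following be a sensible way to check, i.e. building the nested model with larger layers and looking at how much CUDA memory is still allocated after just the forward pass? (This assumes a GPU and activations large enough to show up in the counters; `Drop` is the same wrapper as in the sketch above.)

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Drop(nn.Module):
    def __init__(self, child):
        super().__init__()
        self.child = child

    def forward(self, x):
        return checkpoint(self.child, x)

# Same nested structure as above (1 -> drop(2) -> 3, with 2 = 2a -> drop(2b)),
# but with layers large enough for any saved activations to be visible.
block_2 = nn.Sequential(nn.Linear(4096, 4096), Drop(nn.Linear(4096, 4096)))
model = nn.Sequential(nn.Linear(4096, 4096), Drop(block_2), nn.Linear(4096, 4096)).cuda()

x = torch.randn(2048, 4096, device="cuda", requires_grad=True)

torch.cuda.reset_peak_memory_stats()
out = model(x)
torch.cuda.synchronize()
# If the inner checkpoint really keeps its input alive, this number should be
# noticeably larger than for a version without the inner Drop(2b) wrapper.
print(f"{torch.cuda.memory_allocated() / 2**20:.1f} MiB still allocated after the forward pass")
```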
If my analysis is right, I will have to implement this `drop` operator from scratch such that it avoids this behaviour, correct?
Thank you for making it this far, and sorry if the above explanations are not great; it would be easier with diagrams. Any help on this would be greatly appreciated.