Hello,
I am trying to save memory in a certain forward-backward computation, by realizing that a whole sequence of large tensors can be reconstructed from smaller tensors. Think x_{k+1} = f(x_k, a_k), x_k = g(x_{k+1}, b_k), where the x_k are large and the b_k are smaller, so I really only need to store the final x_K and all the b_k to reconstruct all other x_k.
I found:
https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html
I thought this would nail it, but I am struggling now. This is because pack_hook only gets the (detached) tensor, and I see no way to identify whether it is one of the x_k. While every new node in the graph gets a unique name, this name is not passed to pack_hook.
I can match the tensors by shape, but I have a number of these sequences, and I need to be able to tell them apart. Also, for many arguments of pack_hook, I just want to pass them through unchanged.
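For reference, here is roughly the pattern from the tutorial, as a minimal sketch; the hook really does receive nothing but the (detached) tensor:

```python
import torch

def pack_hook(x):
    # x is the (detached) tensor autograd wants to save; there is no node
    # name or other identifying information attached to it.
    print("pack:", x.shape, x.dtype)
    return x

def unpack_hook(packed):
    # Called during backward; must return the tensor to be used there.
    return packed

a = torch.randn(1000, 1000, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = (a * a).sum()
y.backward()
```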
Does anybody have an idea here? Can you tell me why these hooks (which are just awesome) do not also obtain information that allows me to identify the node they came from? Not doing this restricts people to blunt use cases such as “store all tensors to CPU / disk”, while in general we’d want to be selective about certain tensors and not care about others.
To add details: my current implementation creates certain information for nodes created in the forward pass, for example b_k for x_k. Ideally, I’d now collect these in a dictionary, where the key allows me to identify x_k. pack_hook would then use this dictionary and (say) store b_k instead of x_k in the graph. What I am missing is an idea for what this key could be, so that pack_hook can test whether its input argument is one of the x_k and retrieve the information.
So, if it is def pack_hook(x): ..., I’d need to be able to test whether x corresponds to one of the x_k in the dictionary.
Can you tell me why these hooks (which are just awesome) do not also obtain information that allows me to identify the node they came from?
No fundamental reason; this is a reasonable request. But if I understand correctly, you’re trying to use it to implement selective saving/recompute, and I’m not sure it would really help you in that case. Are you planning to just rerun the user-passed function again during recompute?
The selective storing/saving can be done by nesting hook contexts, e.g. having an inner no-op context over the parts of the forward where you don’t want saved tensors to be stashed on CPU.
we’d want to be selective about certain tensors and not care about others.
There’s an API for selective activation checkpointing: Current and New Activation Checkpointing Techniques in PyTorch | PyTorch.
It doesn’t behave optimally in certain ways, but it mostly does what you want in the common case (save all matmul / compute-intensive ops). It hasn’t really been tested much with nested/recursive checkpointing, however.
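Roughly, usage looks like the sketch below, assuming a recent PyTorch version (the names are from the current torch.utils.checkpoint docs, so double-check them against your version): save the outputs of chosen ops, recompute the rest.

```python
import torch
from functools import partial
from torch.utils.checkpoint import (
    checkpoint,
    create_selective_checkpoint_contexts,
    CheckpointPolicy,
)

# Save the outputs of compute-intensive ops; recompute everything else.
ops_to_save = [torch.ops.aten.mm.default]

def policy_fn(ctx, op, *args, **kwargs):
    if op in ops_to_save:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def fn(x):
    return torch.mm(x, x).relu().sum()

x = torch.randn(64, 64, requires_grad=True)
out = checkpoint(
    fn,
    x,
    use_reentrant=False,
    context_fn=partial(create_selective_checkpoint_contexts, policy_fn),
)
out.backward()
```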
Hello, I am very grateful for your help. In particular, I’d be interested in what you mean by “inner no-op context” and “user-passed function”. I am probably not aware of some tricks here.
Let me explain where my problem is. I’d like to implement gradient computation in the presence of KV caches. Doing this directly is totally out of the question; it would require way too much memory. I’ll be using a form of nested activation checkpointing (which I coded up without torch.utils.checkpoint, it was not that hard). But there is another property I need to exploit here, namely that the KV cache buffers along the “prompt processing” axis (this is a sequential direction, because the prompt is too large to be processed in a single prefill operation – that is the whole point) satisfy x_k = g(x_{k+1}, b_k). Here, the x_k are the KV cache buffers for successive chunks of the prompt being processed, and b_k is much smaller than x_k, because the number of tokens in each chunk is much smaller than the cache length.
Knowing this, I can reconstruct the whole x_k sequence from the b_k and the final x_K, and this saves a lot of memory. In a sense, I can “pack” b_k and some index info instead of x_k.
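As a toy illustration of this kind of structure, take the simplest case, where the cache for chunk k+1 is just the cache for chunk k with the new chunk’s keys/values appended along the token axis; then g is a slice:

```python
import torch

# f appends the new chunk's K/V entries; g drops them again, so x_k can be
# recovered from x_{k+1} and the (much smaller) b_k.
def f(x_k, b_k):
    return torch.cat([x_k, b_k], dim=-2)          # grow along the token axis

def g(x_k_plus_1, b_k):
    return x_k_plus_1[..., : -b_k.shape[-2], :]   # shrink back

cache = torch.zeros(2, 8, 0, 16)                  # (batch, heads, tokens, head_dim)
chunks = [torch.randn(2, 8, 4, 16) for _ in range(3)]

states = [cache]
for b in chunks:
    cache = f(cache, b)
    states.append(cache)

# Reconstruct all earlier caches from the final one plus the small chunks.
rec = cache
for b, ref in zip(reversed(chunks), reversed(states[:-1])):
    rec = g(rec, b)
    assert torch.equal(rec, ref)
```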
To make this work, I need to:
- Store the b_k along the forward pass. I call this “annotation”.
- Pack hook: recognize that the input argument is x_k (or is directly related to it, by reshape etc.), and pack b_k plus some info instead of x_k.
- Unpack hook: reconstruct x_k from the packed information.
My problem is how to detect the right annotation in the pack hook. Namely, for most calls of the pack hook I’d not do anything; I just need to identify those calls that correspond to the x_k.
For now, I am implementing something pretty brittle, which identifies the x_k by shape and also by some entries picked at random locations (so, a kind of hash, stored with the annotation). But I’d love to do something more solid, really.
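Something along these lines, just to show what I mean (the helper names are made up):

```python
import torch

def fingerprint(x, n=8, seed=0):
    # Identify a tensor by its shape plus a few entries sampled at fixed
    # pseudo-random positions. Cheap, but not collision-proof, and it breaks
    # if the values ever get recomputed slightly differently.
    gen = torch.Generator().manual_seed(seed)
    idx = torch.randint(0, x.numel(), (n,), generator=gen)
    return (tuple(x.shape), tuple(x.reshape(-1)[idx].tolist()))

annotations = {}          # fingerprint -> (b_k, k)

def annotate(x_k, b_k, k):
    annotations[fingerprint(x_k.detach())] = (b_k, k)

def lookup(x):
    # Used inside pack_hook to decide whether x is one of the x_k.
    return annotations.get(fingerprint(x))
```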
Thanks for the context.
Have you considered using something like from torch.utils.weak import WeakTensorKeyDictionary to identify the x_k?
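A rough sketch of what I have in mind (whether the lookup hits depends on the tensor autograd hands to pack_hook being the same tensor object you annotated, which you’d have to verify for your case; reconstruct_x is a hypothetical placeholder for your g):

```python
import torch
from torch.utils.weak import WeakTensorKeyDictionary

# Tensor-keyed dict that does not keep its keys alive.
annotations = WeakTensorKeyDictionary()

def annotate(x_k, b_k, k):
    # Called during the forward pass, right after x_k is produced.
    annotations[x_k] = (b_k, k)

def reconstruct_x(b_k, k):
    # Hypothetical placeholder: apply your g here to rebuild x_k from b_k,
    # the index k, and whatever later state you keep around.
    raise NotImplementedError

def pack_hook(x):
    if x in annotations:
        b_k, k = annotations[x]
        return ("packed", b_k, k)       # keep only the small tensor + index
    return ("raw", x)                   # everything else passes through

def unpack_hook(packed):
    tag, *rest = packed
    if tag == "raw":
        return rest[0]
    return reconstruct_x(*rest)

# with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
#     ...  # forward pass
```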
No-op context means nesting a set of saved tensor hooks where the pack/unpack are lambda x: x.
This means that the outer saved tensor hooks context is effectively applied selectively, because autograd only sees the top-most saved tensor hooks.
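For example, a sketch using save_on_cpu as the outer context:

```python
import torch
from torch.autograd.graph import saved_tensors_hooks, save_on_cpu

x = torch.randn(1024, 1024, requires_grad=True)

with save_on_cpu():                    # outer context: offload saved tensors to CPU
    y = x.exp()                        # exp saves its output -> offloaded

    # Inner no-op context: pack/unpack are the identity, so the outer
    # CPU-offload hooks do not apply to this region (autograd only sees
    # the innermost hooks).
    with saved_tensors_hooks(lambda t: t, lambda t: t):
        z = y.sin()                    # sin saves y -> left where it is

    out = z.sum()

out.backward()
```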
The user-passed function is the function you pass to checkpoint, e.g. the fn arg in torch.utils.checkpoint.checkpoint(fn, x). Part of the issue here is that if your implementation of AC needs to recompute this function from the very beginning, selectively choosing to save / not save using hooks doesn’t actually save you any compute.
Tangentially related to your question, but can I ask how important recursive/nested checkpointing is for your use case?
I’m working on a new version of AC that should make it trivial to do a “recompute a particular tensor” operation: graph-based AC · GitHub
But it doesn’t support nested AC (by nested AC I mean possibly recomputing the same forward op more than once), and it seems like adding support may add a good amount of complexity.
Hello, I need nested checkpointing because with one level of checkpointing, the checkpoint sizes would be too large (I’d like to keep the checkpoints on the CPU).