Memory used by `autograd` when `torch.scatter` is involved

I am trying to implement gradient computation (fine-tuning) for large-context LLMs in the presence of sparse inference and KV caching. In this context, the following question arises.

Denote the KV cache buffer content by cache_t, which has (simplified) shape (batch, cache_len), where cache_len is the number of slots in the cache and batch stands for everything else (e.g., batch size, embedding dimension). A large context is processed sequentially, which means that during multi-head attention (MHA), cache_t is mapped to cache_(t+1) by writing the KV content for the new tokens into the cache, overwriting previous content. This is done as

cache_(t+1) = torch.scatter(cache_t, -1, index, new_kv)

where index maps to positions to be overwritten, and new_kv is the new content.

The new content new_kv has shape (batch, new_len), where new_len is much smaller than cache_len. The cache then takes part in the MHA computation, providing the key and value tensors that the query tensor for the t-th sequential step interacts with. For token-by-token generation this would correspond to new_len = 1, but when the whole input is given (during training), we can choose new_len > 1.
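To make the setting concrete, here is a minimal sketch of the sequential loop, with purely illustrative sizes and a placeholder reduction standing in for the actual MHA computation:

```python
import torch

batch, cache_len, new_len, T = 4, 1024, 16, 8   # illustrative sizes only

cache = torch.zeros(batch, cache_len)           # cache_0
loss = torch.zeros(())
for t in range(T):
    new_kv = torch.randn(batch, new_len, requires_grad=True)        # KV content for the new tokens
    index = (t * new_len + torch.arange(new_len)).repeat(batch, 1)  # slots being overwritten
    cache = torch.scatter(cache, -1, index, new_kv)                 # cache_t -> cache_(t+1)
    # Placeholder for the MHA step that reads keys/values from the cache:
    loss = loss + cache.square().sum()

loss.backward()
```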

My question is: If the computation graph looks like this, what memory does autograd store during the forward pass in training mode? Say there are T sequential steps.

  • It could be simple, but wasteful: autograd stores all cache_t, so T times batch * cache_len.
  • It could be a little more involved, but use less memory: autograd stores cache_0 once, and then for each t only two tensors of size batch * new_len, namely new_kv and the content that is overwritten, i.e. torch.gather(cache_t, -1, index). Obviously, I can obtain cache_(t+1) from cache_t if I know these two (see the small check after this list). This would be one time batch * cache_len plus T times 2 * batch * new_len.
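
A small sanity check of the bookkeeping behind the second option, with toy shapes: the pair (new_kv, overwritten content) is enough to move between cache_t and cache_(t+1) in either direction (assuming index has no repeated entries per row):

```python
import torch

cache_t = torch.randn(2, 8)
new_kv  = torch.randn(2, 3)
index   = torch.tensor([[1, 4, 6], [0, 2, 7]])

overwritten = torch.gather(cache_t, -1, index)           # the content that is about to be lost
cache_t1    = torch.scatter(cache_t, -1, index, new_kv)  # the forward update

# Reconstruct cache_t from cache_(t+1), then cache_(t+1) again:
recon_t  = torch.scatter(cache_t1, -1, index, overwritten)
recon_t1 = torch.scatter(recon_t, -1, index, new_kv)

print(torch.equal(recon_t, cache_t), torch.equal(recon_t1, cache_t1))  # True True
```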

Can somebody help me?

scatter is linear w.r.t. cache_t, so that tensor does not need to be saved for backward by autograd.
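
You can check this with toy shapes: the gradients of scatter w.r.t. its input and src only need the index, not the input values (the manual formulas below are just for illustration):

```python
import torch

cache  = torch.randn(2, 8, requires_grad=True)
new_kv = torch.randn(2, 3, requires_grad=True)
index  = torch.tensor([[1, 4, 6], [0, 2, 7]])

out = torch.scatter(cache, -1, index, new_kv)
grad_out = torch.randn_like(out)
out.backward(grad_out)

# Manual backward, using only grad_out and index:
manual_grad_cache  = grad_out.clone().scatter_(-1, index, 0.0)  # overwritten slots get zero grad
manual_grad_new_kv = torch.gather(grad_out, -1, index)

print(torch.allclose(cache.grad, manual_grad_cache))    # True
print(torch.allclose(new_kv.grad, manual_grad_new_kv))  # True
```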

Unrelated question: isn’t KV caching something you’d typically use for inference? Interesting that you need autograd here.

Thanks. So, even if the graph uses these cache_t in further linear operations, they’ll not be stored? By further linear operations, I mean that cache_t contains the keys and values that go into MHA, so the keys are combined linearly with some queries (much smaller), and the values are combined with the attention weights (also much smaller).

This is great.

I am trying to figure out how to differentiate through a certain version (with some blocked gradients) of the computation graph, where KV caches are used inside. The main issue is memory usage, due to the forward pass storing intermediate activations. This is why I am asking.

In the end, you need to fine-tune models with large context widths, right? There should be solutions to this already, but I am trying to find something more generic.

So, even if the graph uses these cache_t in further linear operations,

Ahh, not exactly. For these bilinear types of operations like linear/conv, in order to compute grads w.r.t. the weight you still need to save the input activations.
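
A tiny example of the point, with a plain matmul standing in for the attention products (illustrative shapes; x plays the role of the cached keys/values):

```python
import torch

x = torch.randn(4, 5, requires_grad=True)   # "activation", e.g. keys read from the cache
W = torch.randn(5, 3, requires_grad=True)   # the other factor, e.g. queries / attention weights
y = x @ W

grad_y = torch.randn_like(y)
y.backward(grad_y)

# Each gradient needs the *other* input, so both factors are saved in the forward pass:
print(torch.allclose(W.grad, x.detach().t() @ grad_y))  # grad w.r.t. W needs x
print(torch.allclose(x.grad, grad_y @ W.detach().t()))  # grad w.r.t. x needs W
```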

There should be solutions to this already, but I am trying to find something more generic.

Generally, for activation memory, there’s torch.utils.checkpoint — PyTorch 2.6 documentation, btw.
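
For reference, a minimal sketch of how checkpointing per sequential step could look (hypothetical step function, toy shapes). Note that checkpoint still saves its inputs, so each per-step cache is still kept; only the intermediates inside the step are recomputed:

```python
import torch
from torch.utils.checkpoint import checkpoint

def step(cache, new_kv, index):
    # Hypothetical per-step work: write into the cache, then compute something on it.
    cache = torch.scatter(cache, -1, index, new_kv)
    out = cache.tanh().sum()          # stand-in for the attention computation
    return cache, out

cache = torch.zeros(2, 8)
losses = []
for t in range(4):
    new_kv = torch.randn(2, 2, requires_grad=True)
    index = torch.tensor([[2 * t, 2 * t + 1]] * 2)
    # Intermediates inside `step` are not saved; they are recomputed during
    # backward from the saved inputs (cache, new_kv, index).
    cache, out = checkpoint(step, cache, new_kv, index, use_reentrant=False)
    losses.append(out)

torch.stack(losses).sum().backward()
```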

Thanks. I am already planning to use activation checkpointing. But if all the cache_t are going to be stored, this may still be too much.

Can you think of a way to figure this out? I have the code. Is there some hook mechanism that would allow me to find out which auxiliary tensors are stored during the training-mode forward pass?

Hello, I found Hooks for autograd saved tensors — PyTorch Tutorials 2.6.0+cu124 documentation. With this, I see a (low-level) way to do what I want.
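
For anyone who finds this later, a minimal sketch of what I mean, using torch.autograd.graph.saved_tensors_hooks to log the shape of every tensor autograd saves for backward (toy shapes, placeholder ops):

```python
import torch

saved_shapes = []

def pack(t):
    # Called once for every tensor that autograd saves for backward.
    saved_shapes.append(tuple(t.shape))
    return t

def unpack(t):
    return t

cache  = torch.randn(4, 1024, requires_grad=True)
new_kv = torch.randn(4, 16, requires_grad=True)
index  = torch.arange(16).repeat(4, 1)

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    cache1 = torch.scatter(cache, -1, index, new_kv)
    out = cache1.square().sum()   # a non-linear use of cache1, so cache1 itself gets saved

print(saved_shapes)  # expect the index (4, 16) and cache1 (4, 1024), but not the original cache
```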

But I am really struggling to understand how torch.utils.checkpoint works. I am facing an instance of nested checkpointing.
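
The minimal pattern I am trying to scale up looks something like this (hypothetical inner/outer functions, non-reentrant mode):

```python
import torch
from torch.utils.checkpoint import checkpoint

def inner(x):
    return torch.tanh(x @ x.t())

def outer(x):
    # An inner checkpoint nested inside an outer one.
    h = checkpoint(inner, x, use_reentrant=False)
    return (h @ x).relu()

x = torch.randn(8, 8, requires_grad=True)
y = checkpoint(outer, x, use_reentrant=False)
y.sum().backward()
print(x.grad.shape)   # torch.Size([8, 8])
```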

Do you happen to know of any complete worked-through non-trivial example of nested checkpointing done with torch.utils.checkpoint? Thanks!