Gradient Checkpointing: More Checkpoints, Higher Memory Usage

So I was playing around, trying to learn gradient checkpointing, and I found an interesting behavior that does not match my understanding of the paper: there seems to be a sweet spot for the number of checkpoints, and going beyond it makes memory usage increase again. I see exactly the same behavior with checkpoint_sequential and checkpoint. Here's a link to my code with checkpoint_sequential (go back a directory for the non-sequential version). I am working on a 2080 Super, and I made an overly simple linear model (linear layers, ReLUs, ending with a sigmoid) sized so that it only just fails to fit into memory (if you want the code here, I can replicate it with a nicer-looking network). I get the following results:

| Number of splits | nvidia-smi memory |
|---|---|
| 2 | 7212 MB |
| 4 | 6000 MB |
| 8 | 5428 MB |
| 10 | 5810 MB |
| 16 | 6190 MB |

The non-sequential version gives similar results (one split actually uses 7908 MB!). cuda.max_memory_allocated() shows the same trend (with smaller numbers), but cuda.memory_allocated() does not.

So my understanding is that the more checkpoints I use, the lower my memory usage should be. Instead, my results show something very different, which reminds me more of how there is an optimal number of CPUs to use in parallel processing (beyond which the overhead starts to consume too many resources). So:

  • Do I have a bug somewhere?
  • Am I misunderstanding the paper?
    • If so, how do I find the optimal number of checkpoints?

Do I have a bug somewhere?

I don’t think so.

So my understanding is that the more checkpoints I use, the lower my memory usage should be.

No, this is not true. Checkpointing only saves the input/output of each checkpointed chunk and is able to free all the buffers in between. But if a checkpointed chunk is too small, there are no intermediate buffers to free, so you won't see any memory improvement.
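A toy accounting of what is held between the forward and backward pass makes this concrete. It is a sketch under simplifying assumptions, not how PyTorch measures memory: `act_sizes[i]` stands for whatever layer `i` saves for backward (approximated by the size of its input), the layers are split into contiguous segments, and each checkpointed segment keeps only its own input. All function names are hypothetical.

```python
from typing import List


def split_evenly(n_layers: int, segments: int) -> List[range]:
    """Split layer indices 0..n_layers-1 into `segments` contiguous, non-empty chunks."""
    if not 1 <= segments <= n_layers:
        raise ValueError("segments must be between 1 and n_layers")
    size, extra = divmod(n_layers, segments)
    chunks, start = [], 0
    for i in range(segments):
        # Spread the remainder over the first `extra` chunks.
        end = start + size + (1 if i < extra else 0)
        chunks.append(range(start, end))
        start = end
    return chunks


def kept_without_checkpointing(act_sizes: List[int]) -> int:
    # Every layer keeps its saved tensor until backward.
    return sum(act_sizes)


def kept_with_checkpointing(act_sizes: List[int], segments: int) -> int:
    # Each segment keeps only its input; everything inside the segment
    # is freed after forward and recomputed during backward.
    return sum(act_sizes[chunk[0]] for chunk in split_evenly(len(act_sizes), segments))


if __name__ == "__main__":
    sizes = [4] * 8  # eight equally sized layers, arbitrary units
    for k in (1, 2, 4, 8):
        print(k, kept_with_checkpointing(sizes, k), "vs", kept_without_checkpointing(sizes))
```

With one layer per segment (`k = 8`) the checkpointed model keeps exactly as much as the plain one: there is nothing inside a segment to free. This only counts what survives the forward pass; the memory needed to recompute a segment during backward is what keeps `k = 1` from being the best choice.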

If so, how do I find the optimal number of checkpoints?

It is hard to say, as it has to balance not recomputing too big a chunk (which uses a lot of memory) against not having too small a chunk (for which no intermediate buffer can be freed).
You will most likely have to find it experimentally for your particular model.
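The balance described above can be illustrated with a back-of-the-envelope cost model. This is an assumption, not a measurement: all `n` layers produce equally sized outputs, each of the `k` segments keeps one saved input until backward, and during backward one segment at a time is recomputed and holds all of its activations. Peak activation memory is then roughly `k + ceil(n / k)` units, which is large at both extremes and smallest near `k ≈ sqrt(n)`. The names are hypothetical, and the model ignores gradients, optimizer state, the caching allocator and the CUDA context, so it will not reproduce nvidia-smi numbers.

```python
import math
from typing import Dict, Tuple


def peak_activation_units(n_layers: int, segments: int) -> int:
    """Toy peak activation memory, in units of one layer's output.

    Assumes equally sized layer outputs: `segments` saved segment inputs are
    held until backward, plus all activations of the one segment being
    recomputed at a time (the largest segment dominates).
    """
    if not 1 <= segments <= n_layers:
        raise ValueError("segments must be between 1 and n_layers")
    saved_inputs = segments
    largest_segment = math.ceil(n_layers / segments)
    return saved_inputs + largest_segment


def sweep_segments(n_layers: int) -> Tuple[int, Dict[int, int]]:
    """Evaluate every segment count and return (best_count, cost_per_count)."""
    costs = {k: peak_activation_units(n_layers, k) for k in range(1, n_layers + 1)}
    best = min(costs, key=costs.get)  # ties resolve to the smallest segment count
    return best, costs


if __name__ == "__main__":
    best, costs = sweep_segments(16)
    print("best:", best)  # prints "best: 4"
    for k in (1, 2, 4, 8, 16):
        print(k, costs[k])
```

For 16 layers the toy cost is 17, 10, 8, 10 and 17 units at 1, 2, 4, 8 and 16 segments: it rises on both sides of the minimum, the same qualitative shape as the table in the question. Real layers differ in size and carry extra overhead, so the actual optimum still has to be measured.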


Thanks! This makes sense to me, but it didn't match my understanding of the paper.

Are there any best practices or intuitions you or others can offer?

Hi,

I am afraid I don't have a lot of practical experience. But for vision models, it is common to checkpoint the conv/relu/norm blocks. You can check here how this is used for the torchvision models.
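A minimal sketch of that pattern, assuming a PyTorch version whose `torch.utils.checkpoint.checkpoint` accepts `use_reentrant`: each conv/norm/ReLU block is wrapped in a checkpoint so that only the block's input is saved, and its intermediates are recomputed in backward. `CheckpointedConvNet` and `param_grads` are hypothetical names, the sizes are arbitrary, and the code follows the spirit of the torchvision approach rather than copying it. It runs on CPU only to check that checkpointing leaves the gradients unchanged; the memory saving itself shows up on a GPU with a much larger model.

```python
import copy

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedConvNet(nn.Module):
    """Hypothetical toy CNN whose conv/norm/relu blocks can be checkpointed."""

    def __init__(self, channels: int = 8, num_blocks: int = 4, use_checkpoint: bool = True):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.blocks = nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(channels, channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(channels),
                nn.ReLU(),
            )
            for _ in range(num_blocks)
        )
        self.head = nn.Linear(channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.use_checkpoint and self.training:
                # Only the block input is kept; conv/norm/relu outputs are
                # dropped after forward and recomputed during backward.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        # Global average pooling over H and W, then a linear head.
        return self.head(x.mean(dim=(2, 3)))


def param_grads(model: nn.Module, x: torch.Tensor) -> list:
    """Run one forward/backward pass and return copies of all parameter gradients."""
    model.zero_grad()
    model(x).sum().backward()
    return [p.grad.clone() for p in model.parameters()]


if __name__ == "__main__":
    torch.manual_seed(0)
    plain = CheckpointedConvNet(use_checkpoint=False)
    ckpt = copy.deepcopy(plain)  # identical weights, checkpointing switched on
    ckpt.use_checkpoint = True
    x = torch.randn(2, 8, 16, 16)
    same = all(torch.allclose(a, b, atol=1e-5)
               for a, b in zip(param_grads(plain, x), param_grads(ckpt, x)))
    print("gradients match:", same)
```

To see the memory effect on a GPU, one option is to call `torch.cuda.reset_peak_memory_stats()` before a training step and read `torch.cuda.max_memory_allocated()` after it, once with `use_checkpoint=True` and once with `False`.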

Have you discovered why memory increases when using more splits?
Maybe with more splits there are more checkpoints to keep in memory, and that causes the memory to increase?