Gradient Checkpointing: More Checkpoints, Higher Memory Usage

  • Do I have a bug somewhere?

I don’t think so.

So my understanding is that the more checkpoints I do the lower my memory usage should be.

No, that is not true. Checkpointing saves only the input and output of each checkpointed segment and frees all the intermediate buffers inside it. But if a checkpointed segment is too small, there are no intermediate buffers to free, so you won’t see any memory improvement.

If so, how do I find the optimal number of checkpoints?

It is hard to say, because you have to balance two things: a chunk that is too big uses a lot of memory when it is recomputed, while a chunk that is too small has no intermediate buffers to free.
You will most likely have to find the right number experimentally for your particular model.
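
To make that trade-off concrete, here is a rough back-of-the-envelope sketch in plain Python. It is not a measurement and not how PyTorch accounts for memory. It assumes a hypothetical chain of identical layers where every activation costs one "unit", split into equally sized segments. The function names are made up for illustration.

```python
import math


def estimated_peak_units(num_layers, num_segments):
    """Very rough peak activation memory, in units of one layer's activation.

    Assumptions (illustrative only): all layers are identical, one input is
    kept per segment, and the backward pass rebuilds one segment at a time.
    """
    saved_inputs = num_segments  # one saved tensor per checkpointed segment
    recomputed = math.ceil(num_layers / num_segments)  # largest segment rebuilt
    return saved_inputs + recomputed


def best_segment_count(num_layers):
    """Segment count that minimises the toy estimate above."""
    candidates = range(1, num_layers + 1)
    return min(candidates, key=lambda k: estimated_peak_units(num_layers, k))


if __name__ == "__main__":
    layers = 64
    for segments in (1, 2, 4, 8, 16, 32, 64):
        print(segments, estimated_peak_units(layers, segments))
    print("best:", best_segment_count(layers))
```

In this toy model, 64 layers cost about 65 units with either 1 segment or 64 segments, and about 16 units with 8 segments, so adding checkpoints helps only up to a point. Real models have layers of very different sizes, so treat this only as intuition and still measure peak memory on your own model.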
