Gradient checkpointing and its effect on memory and runtime

Hello,

I am trying to understand how the number of checkpoints in gradient checkpointing affects the memory and runtime for computing gradients.

I found this notebook that explains how gradient checkpointing works. I modified it slightly and ran a few experiments to see how the use_reentrant and segments arguments affect memory and runtime.
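Below is a minimal sketch of this kind of experiment. It assumes a toy stack of Linear/ReLU layers in place of the notebook's model, which isn't reproduced here. make_model and run_once are hypothetical helper names. Runtime is wall-clock time from time.perf_counter, and peak memory is only reported on CUDA, via torch.cuda.max_memory_allocated. As a sanity check, the gradients for every setting should match the segments=0 run.

```python
import time

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential


def make_model(depth=8, width=256):
    # Hypothetical stand-in for the notebook's model: depth x (Linear, ReLU).
    torch.manual_seed(0)
    layers = []
    for _ in range(depth):
        layers += [nn.Linear(width, width), nn.ReLU()]
    return nn.Sequential(*layers)


def run_once(model, x, segments, use_reentrant):
    """One forward + backward pass; segments=0 means no checkpointing."""
    model.zero_grad(set_to_none=True)
    # With use_reentrant=True, parameters inside a checkpointed segment only
    # receive gradients if at least one input requires grad, so use such an input.
    x = x.detach().clone().requires_grad_(True)
    if x.is_cuda:
        torch.cuda.reset_peak_memory_stats(x.device)
    start = time.perf_counter()
    if segments == 0:
        out = model(x)
    else:
        out = checkpoint_sequential(model, segments, x, use_reentrant=use_reentrant)
    out.sum().backward()
    if x.is_cuda:
        torch.cuda.synchronize(x.device)  # wait for queued kernels before timing
    elapsed = time.perf_counter() - start
    peak = torch.cuda.max_memory_allocated(x.device) if x.is_cuda else None
    grads = [p.grad.clone() for p in model.parameters()]
    return elapsed, peak, grads


if __name__ == "__main__":
    model = make_model()
    x = torch.randn(64, 256)
    for use_reentrant in (False, True):
        for segments in (0, 1, 2, 4, 8):
            elapsed, peak, _ = run_once(model, x, segments, use_reentrant)
            print(use_reentrant, segments, f"{elapsed:.4f}s", peak)
```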

I was surprised by the following results:

  • For memory, the data shows almost no improvement with checkpointing enabled compared to no checkpointing, regardless of the number of checkpoints and of whether use_reentrant is set. From this thread, I understand that the optimal number of checkpoints depends on the specific scenario. However, I still expected at least some reduction in memory consumption, and I’m not sure why there isn’t one.
  • For runtime, the experiment shows that runtime increases with the number of checkpoints. However, from the Chen2016 paper [1], my understanding is that runtime should increase by roughly 30% when checkpointing is enabled, because of the extra forward pass, and should stay roughly the same regardless of the number of checkpoints. Is the increasing runtime due to the extra overhead of keeping more checkpoints?
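The expectation in the second bullet follows from the idealized cost model in Chen et al. [1], sketched roughly below. It assumes n layers of equal cost, each storing one unit of activation memory, split into k segments. Only the segment inputs are kept, plus one segment's activations while it is recomputed. It also assumes, as a common rule of thumb that is not taken from the paper, that a backward pass costs about twice a forward pass F.

```latex
\begin{aligned}
M(k) &\approx \underbrace{k}_{\text{stored segment inputs}}
      + \underbrace{n/k}_{\text{one segment recomputed}},
\qquad \arg\min_k M(k) = \sqrt{n}, \qquad M(\sqrt{n}) = 2\sqrt{n}, \\
T_{\text{ckpt}} &\approx \underbrace{F}_{\text{forward}}
      + \underbrace{F}_{\text{recompute}}
      + \underbrace{2F}_{\text{backward}} = 4F,
\qquad \frac{T_{\text{ckpt}}}{T_{\text{plain}}} \approx \frac{4F}{3F} \approx 1.33 .
\end{aligned}
```

In this idealized model, every layer is recomputed exactly once, so the recomputation term does not depend on k. That is where the expectation of a roughly constant overhead of about 30% comes from.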

Below are my results. Note that segments=0 means checkpointing is disabled.

Any insight would help. Thanks in advance!

use_reentrant  segments  Runtime (s)  Current memory (bytes)  Peak memory (bytes)
FALSE          0         109.10       413527                  154106538
TRUE           0         109.10       413527                  154106538
FALSE          1         109.12       402938                  154105009
TRUE           1         110.14       394474                  154121963
FALSE          2         116.23       382560                  154127403
TRUE           2         114.86       378951                  154162031
FALSE          5         116.82       428681                  154214720
TRUE           5         117.23       507778                  154334766
FALSE          10        119.44       384482                  154244424
TRUE           10        118.84       515236                  154438154
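To put the table in relative terms, the short script below computes each row's change against the segments=0 baseline. It uses only the numbers in the table. The runtime overhead tops out at about 9.5% (use_reentrant=FALSE, segments=10), and peak memory moves by at most about 0.2% in either direction.

```python
# (use_reentrant, segments) -> (runtime in s, peak memory in bytes), copied from the table
results = {
    (False, 0): (109.10, 154106538),
    (True, 0): (109.10, 154106538),
    (False, 1): (109.12, 154105009),
    (True, 1): (110.14, 154121963),
    (False, 2): (116.23, 154127403),
    (True, 2): (114.86, 154162031),
    (False, 5): (116.82, 154214720),
    (True, 5): (117.23, 154334766),
    (False, 10): (119.44, 154244424),
    (True, 10): (118.84, 154438154),
}

baseline_time, baseline_peak = results[(False, 0)]


def relative_change(value, baseline):
    """Percent change of `value` relative to `baseline`."""
    return 100.0 * (value - baseline) / baseline


for key in sorted(results):
    t, peak = results[key]
    print(
        key,
        f"runtime {relative_change(t, baseline_time):+.2f}%",
        f"peak memory {relative_change(peak, baseline_peak):+.3f}%",
    )
```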

[1] Chen, Tianqi, Bing Xu, Chiyuan Zhang, and Carlos Guestrin. “Training deep nets with sublinear memory cost.” arXiv preprint arXiv:1604.06174 (2016).

Activation checkpointing produces more pronounced reductions in peak memory when your activations are large relative to the model size. What do those numbers look like for your example?
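One way to get those numbers is sketched below, assuming a plain nn.Module on CPU. It compares the bytes held by the parameters with the bytes of the tensors autograd saves for backward during one forward pass, captured with torch.autograd.graph.saved_tensors_hooks. Here "activations" means those saved intermediate tensors, not how often units fire. param_bytes and saved_activation_bytes are hypothetical helper names. Deduplicating by data_ptr is a rough approximation, so views of one storage are counted once.

```python
import torch
import torch.nn as nn


def param_bytes(model):
    """Bytes held by the model's parameters."""
    return sum(p.numel() * p.element_size() for p in model.parameters())


def saved_activation_bytes(model, x):
    """Approximate bytes autograd saves for backward in one forward pass,
    excluding tensors that share memory with a parameter."""
    param_ptrs = {p.data_ptr() for p in model.parameters()}
    saved = {}  # data_ptr -> size in bytes

    def pack(t):
        if t.data_ptr() not in param_ptrs:
            saved[t.data_ptr()] = t.numel() * t.element_size()
        return t  # identity: keep the tensor exactly as autograd would

    def unpack(t):
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        model(x).sum()
    return sum(saved.values())
```

If the saved-activation total is small next to the parameters (plus their gradients and any optimizer state), checkpointing has little to remove from the peak.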

Is the increasing runtime due to the extra overhead required for keeping more checkpoints?

In checkpoint_sequential, the last segment is not checkpointed. The fewer segments you have, the larger that non-checkpointed region, and the smaller the runtime.
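The toy cost model below, in plain Python, makes the effect of the uncheckpointed last segment concrete. It mirrors how checkpoint_sequential splits the layers: the segment size is n // segments, the first segments - 1 chunks are checkpointed, and the remaining layers run normally. As a simplification, it assumes every layer has the same cost and saves one equally sized activation. It also ignores the transient activations of a segment while that segment is being recomputed. checkpoint_sequential_costs is a hypothetical name.

```python
def checkpoint_sequential_costs(n_layers, segments):
    """Toy model of checkpoint_sequential's split (segments=0: no checkpointing).
    Assumes every layer saves exactly one activation of equal size."""
    if segments == 0:
        return {"recomputed": 0, "saved_activations": n_layers}
    size = n_layers // segments
    checkpointed_layers = size * (segments - 1)  # first segments-1 chunks
    last_segment = n_layers - checkpointed_layers  # runs without checkpointing
    return {
        # every checkpointed layer runs a second forward during backward
        "recomputed": checkpointed_layers,
        # one stored input per checkpointed segment, plus every activation
        # of the final, non-checkpointed segment
        "saved_activations": (segments - 1) + last_segment,
    }


for k in (0, 1, 2, 4, 8, 16):
    print(k, checkpoint_sequential_costs(16, k))
```

With 16 layers, recomputation grows from 0 layers (segments=1, where nothing is checkpointed at all) to 15 layers (segments=16). Meanwhile, saved activations are lowest around segments=4, which is √16.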

Thanks for your reply!

Yes, there are dropout layers in the model, so the neurons are not always active, if that’s what you mean. I’ve removed the dropout layers and tried again, but the memory trend is not much different. The peak memory I have before training is 146.98 MB. With checkpointing enabled and segments=1 it is 146.94 MB, and with segments=2 it is 146.96 MB. Checkpointing doesn’t seem to help much, but I might be misunderstanding something here.

In checkpoint_sequential, the last segment is not checkpointed.

That makes a lot more sense. Do you know whether checkpoint and checkpoint_sequential behave the same way in this respect? That is, if I use checkpoint with fewer segments, should I also expect a smaller runtime because the last segment is not checkpointed?
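For comparison, here is a sketch of segmenting a model by hand with torch.utils.checkpoint.checkpoint. Each checkpoint() call checkpoints exactly the callable it wraps, and checkpoint itself takes no segments argument. The helper below therefore decides for itself whether the last chunk is checkpointed. checkpoint_all_segments is a hypothetical name. The sketch assumes an nn.Sequential model and use_reentrant=False, which does not require the input to need grad.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def checkpoint_all_segments(model, segments, x, use_reentrant=False):
    """Hypothetical helper: split an nn.Sequential into `segments` chunks and
    wrap every chunk, including the last one, in checkpoint()."""
    layers = list(model.children())
    size = -(-len(layers) // segments)  # ceiling division so no layer is dropped
    for start in range(0, len(layers), size):
        chunk = nn.Sequential(*layers[start:start + size])
        x = checkpoint(chunk, x, use_reentrant=use_reentrant)
    return x
```

Unlike checkpoint_sequential, segments=1 here recomputes the whole model during backward.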

I’ve also run experiments where checkpointing a function actually decreased the runtime compared to the non-checkpointed version. Do you have any idea why that could be?