"preserve_rng_state" in gradient checkpoint

I am trying to use gradient checkpointing so that I can fine-tune a huge transformer model on a 12 GB GPU.
I am confused about the argument preserve_rng_state.
Basically, I don't understand the following section from the official documentation.

Checkpointing is implemented by rerunning a forward-pass segment for each checkpointed segment during backward. This can cause persistent states like the RNG state to be advanced than they would without checkpointing. By default, checkpointing includes logic to juggle the RNG state such that checkpointed passes making use of RNG (through dropout for example) have deterministic output as compared to non-checkpointed passes. The logic to stash and restore RNG states can incur a moderate performance hit depending on the runtime of checkpointed operations. If deterministic output compared to non-checkpointed passes is not required, supply preserve_rng_state=False to checkpoint or checkpoint_sequential to omit stashing and restoring the RNG state during each checkpoint.

The stashing logic saves and restores the RNG state for the current device and the device of all cuda Tensor arguments to the run_fn . However, the logic has no way to anticipate if the user will move Tensors to a new device within the run_fn itself. Therefore, if you move Tensors to a new device (“new” meaning not belonging to the set of [current device + devices of Tensor arguments]) within run_fn , deterministic output compared to non-checkpointed passes is never guaranteed.

I have a basic question:

  • What is the meaning of this sentence:

This can cause persistent states like the RNG state to be advanced than they would without checkpointing.

What I understand is that the random number generator is only needed for weight initialization of the layers, which happens during model initialization, not in the forward method of the model. Am I right?
If so, what does the RNG state have to do with checkpointing?
A simple example would be very helpful.

The random number generator is also needed for layers such as Dropout. If you checkpoint a module that contains Dropout, you want the dropout mask to be the same during the forward pass and when the tensors are re-computed in the backward pass.
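
Here is a minimal sketch of that using torch.utils.checkpoint.checkpoint (checkpoint and preserve_rng_state are the real API; the toy Linear + Dropout block, the tensor sizes, and the grad_of helper are just illustrative, not anything from your model):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    block = nn.Sequential(nn.Linear(16, 16), nn.Dropout(p=0.5))
    x = torch.randn(4, 16)

    def grad_of(run):
        # Reset the seed, run the given forward on a fresh leaf copy of x,
        # and return the gradient of a scalar loss w.r.t. that input.
        torch.manual_seed(0)
        inp = x.clone().requires_grad_(True)
        run(inp).sum().backward()
        return inp.grad

    # Reference: plain (non-checkpointed) forward/backward.
    ref = grad_of(lambda inp: block(inp))

    # Checkpointed with the default preserve_rng_state=True: the RNG state is
    # stashed before the segment runs and restored when the segment is re-run
    # during backward, so Dropout draws the same mask both times and the
    # gradient matches the non-checkpointed reference.
    same = grad_of(lambda inp: checkpoint(block, inp))
    print(torch.allclose(ref, same))    # expected: True

    # With preserve_rng_state=False the stash/restore is skipped (slightly
    # faster), so the recomputation during backward starts from an
    # already-advanced RNG state, may draw a different Dropout mask, and the
    # gradient is no longer guaranteed to match the reference.
    diff = grad_of(lambda inp: checkpoint(block, inp, preserve_rng_state=False))
    print(torch.allclose(ref, diff))    # typically: False

This is also what the quoted sentence is getting at: during backward the checkpointed segment is run a second time, so anything inside it that consumes random numbers (Dropout here) draws from the generator again, leaving the RNG state further advanced than it would be without checkpointing unless the state is stashed and restored around the re-run.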