We haven’t implemented checkpointing in PyTorch yet, but we are thinking about it.
However, if you want to manually create some checkpointing logic, it’s not that difficult.
See this thread where someone trades off memory for timesteps, sort of like a manual scan_checkpoint:
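
Separately from that thread, here is a rough sketch of what hand-written checkpointing over timesteps could look like. It is not taken from the linked post. The function name `checkpointed_rnn_backward`, the `segment_len` parameter, and the toy loss (the sum of the final hidden state) are all made up for illustration. The idea: run the forward pass without building a graph and keep only the hidden state at segment boundaries, then walk the segments in reverse, recomputing each one with autograd enabled and passing the hidden-state gradient back to the previous segment.

```python
import torch


def checkpointed_rnn_backward(cell, xs, h0, segment_len):
    """Hypothetical helper: accumulates parameter grads for loss = h_T.sum().

    Only one segment's graph is alive at a time, at the cost of running
    the forward computation twice.
    """
    T = xs.size(0)

    # 1) Forward pass with no graph; store only segment-boundary states.
    boundaries = [h0.detach()]
    with torch.no_grad():
        h = h0
        for t in range(T):
            h = cell(xs[t], h)
            if (t + 1) % segment_len == 0 and t + 1 < T:
                boundaries.append(h)
        loss = h.sum()

    # 2) Backward pass, one segment at a time, last segment first.
    grad_h = None  # gradient of the loss w.r.t. the current segment's output
    starts = list(range(0, T, segment_len))
    for start, h_start in zip(reversed(starts), reversed(boundaries)):
        h_start = h_start.detach().requires_grad_(True)
        h = h_start
        for t in range(start, min(start + segment_len, T)):
            h = cell(xs[t], h)  # recompute this segment with autograd on
        if grad_h is None:
            h.sum().backward()   # last segment: differentiate the toy loss
        else:
            h.backward(grad_h)   # earlier segments: chain the incoming grad
        grad_h = h_start.grad    # hand the gradient to the previous segment

    return loss


torch.manual_seed(0)
cell = torch.nn.RNNCell(3, 4)
xs = torch.randn(10, 2, 3)   # (timesteps, batch, features)
h0 = torch.zeros(2, 4)
loss = checkpointed_rnn_backward(cell, xs, h0, segment_len=4)
```

After the call, `cell`'s parameters hold `.grad` values that should match an ordinary full backward pass up to floating-point error. A larger `segment_len` means fewer stored states but a bigger graph per segment, so it is the knob for the memory/compute trade-off.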