Optimizing the trade-off between memory usage and computation - Checkpoint


I am following the OpenAI implementation of the paper Training Deep Nets with Sublinear Memory Cost.

I would like to know how the "no-longer-needed" intermediate features are freed in PyTorch, since it is hard to read the C++ source code of the backward function.

In the above implementation, during the backward pass, each f node and b node should be deleted from memory as soon as it has been used and is no longer needed, eliminating unnecessary memory usage. On top of that, they implement a checkpointing approach: only a few feature maps are stored, chosen strategically, and the rest are recomputed during backward, which reduces memory cost without increasing computation time too much.

In PyTorch, are the intermediate results also freed when they are no longer needed?
How can I keep track of my GPU memory usage while the entire forward-backward pass is computed?
Can anyone help me with this, please?


Yes, intermediate results are freed as soon as they are no longer needed.
You can also use torch.utils.checkpoint to get checkpointing.
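A minimal sketch of what this looks like in practice, assuming a recent PyTorch version; the layer sizes and segment count here are arbitrary choices for illustration. `checkpoint_sequential` splits a `nn.Sequential` into segments and only stores the activations at segment boundaries, recomputing the rest during backward:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# A hypothetical deep stack of layers, just to have something to checkpoint.
layers = nn.Sequential(*[nn.Linear(128, 128) for _ in range(8)])

# The input must require grad so the recomputed segments are differentiable.
x = torch.randn(4, 128, requires_grad=True)

# Split the stack into 2 segments; activations inside each segment are
# not stored during forward but recomputed on the fly during backward.
out = checkpoint_sequential(layers, 2, x)
out.sum().backward()
```

For arbitrary (non-sequential) sub-graphs there is also the lower-level `torch.utils.checkpoint.checkpoint(fn, *args)`, which wraps a single function call the same way.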
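As for tracking GPU usage during the forward-backward pass: one way is to query the CUDA caching allocator's counters around the pass. A rough sketch, assuming a CUDA device is available (the `report_cuda_memory` helper and the tiny model are hypothetical, just for illustration); it falls back to CPU so the code still runs without a GPU:

```python
import torch

def report_cuda_memory(tag: str) -> None:
    # Hypothetical helper: prints current and peak CUDA memory allocated
    # by tensors, as tracked by PyTorch's caching allocator.
    if not torch.cuda.is_available():
        print(f"{tag}: CUDA not available")
        return
    cur = torch.cuda.memory_allocated() / 2**20
    peak = torch.cuda.max_memory_allocated() / 2**20
    print(f"{tag}: {cur:.1f} MiB allocated, {peak:.1f} MiB peak")

device = "cuda" if torch.cuda.is_available() else "cpu"

# Reset the peak counter, run one forward-backward pass, then inspect usage.
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()

model = torch.nn.Linear(128, 128).to(device)
x = torch.randn(4, 128, device=device)
model(x).sum().backward()
report_cuda_memory("after backward")
```

The peak value is what checkpointing is meant to reduce: comparing `max_memory_allocated()` with and without checkpointing shows the saving. Note these counters only see memory allocated through PyTorch, not other processes on the GPU (for that, `nvidia-smi` is the usual tool).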
