Gradient Checkpointing in PyTorch?

I’m curious to hear the developers’ thoughts on gradient checkpointing.

Memory limitations are one of the biggest restrictions I run into, both with PyTorch and with deep learning in general. Gradient checkpointing, which trades extra forward-pass computation for a much smaller activation footprint during backward, seems like an interesting and possibly fruitful solution, particularly if it were baked right into the library itself.

There is some work in progress on checkpointing at https://github.com/pytorch/pytorch/pull/4594
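For reference, here is a minimal sketch of the kind of usage I’d imagine, assuming the PR lands as a `torch.utils.checkpoint.checkpoint(function, *args)` entry point (the module path and signature are my assumption, since the PR is still WIP):

```python
import torch
import torch.nn as nn
# Assumed entry point; may change before the PR is merged.
from torch.utils.checkpoint import checkpoint

# A deep stack of blocks whose intermediate activations would
# normally all be kept alive for the backward pass.
blocks = nn.ModuleList(
    nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()) for _ in range(50)
)

x = torch.randn(32, 1024, requires_grad=True)

# Checkpointed forward: activations inside each block are dropped
# after the forward pass and recomputed on the fly during backward,
# trading extra compute for lower peak memory.
out = x
for block in blocks:
    out = checkpoint(block, out)

out.sum().backward()
```

The trade-off is roughly one extra forward pass through each checkpointed segment in exchange for only storing the segment boundaries instead of every intermediate activation.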