Gradient Checkpointing basic example?

Hello,

I am trying to implement gradient checkpointing in my code to work around GPU memory limitations, and I found a PyTorch implementation. However, I could not find any examples of it online. All I see right now is:

>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
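For reference, here is a minimal runnable sketch of that docs snippet, assuming PyTorch 2.x. In the signature `checkpoint_sequential(functions, segments, input, ...)`, the `chunks` from the docs example is the `segments` argument: the number of pieces the `nn.Sequential` is split into. Only the activations at segment boundaries are kept, and everything inside a segment is recomputed during backward. `input_var` is simply the normal input batch you would otherwise pass to `model(...)`. The layer sizes and `chunks = 4` are arbitrary illustration values, and `use_reentrant=False` is passed explicitly because recent versions warn when it is omitted.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)

# Toy 8-layer sequential model; sizes are arbitrary illustration values.
model = nn.Sequential(*[nn.Sequential(nn.Linear(64, 64), nn.ReLU()) for _ in range(8)])

# `chunks` == the `segments` argument: split the model into 4 segments.
# Only segment-boundary activations are stored; the rest are recomputed
# when backward() reaches each segment.
chunks = 4

# `input_var` is just the ordinary input batch.
input_var = torch.randn(16, 64)

# Drop-in replacement for `output = model(input_var)`.
output = checkpoint_sequential(model, chunks, input_var, use_reentrant=False)

loss = output.sum()
loss.backward()  # gradients are written to model.parameters() as usual
```

The forward result is numerically the same as `model(input_var)`; the only difference is that less activation memory is held between the forward and backward passes, at the cost of extra forward computation.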

This is for sequential models; I could not find anything for non-sequential models, even though that case is also implemented in PyTorch. I am also not sure what input_var and chunks are, or how the whole thing gels with the rest of the training code. A very simple, self-contained example of how this can be used to train a model would be a really useful resource. Thanks!
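For the non-sequential case, the general function is `torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=...)`, which you call inside your own `forward`. Below is a minimal sketch, assuming PyTorch 2.x, of a full training loop around a model with skip connections. `Block`, `SkipNet`, the dimensions, and the random data are hypothetical illustration choices, not anything from the docs; on a model this small there is no real memory saving, it only shows where the calls go.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    """Small MLP block; its internal activations are what checkpointing avoids storing."""
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, x):
        return self.net(x)

class SkipNet(nn.Module):
    """Hypothetical non-sequential model: two residual blocks and a classifier head."""
    def __init__(self, dim=32, num_classes=10):
        super().__init__()
        self.block1 = Block(dim)
        self.block2 = Block(dim)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):
        # checkpoint(fn, *args) runs fn without keeping its intermediate
        # activations and re-runs it during backward to recompute them.
        h = x + checkpoint(self.block1, x, use_reentrant=False)
        h = h + checkpoint(self.block2, h, use_reentrant=False)
        return self.head(h)  # not checkpointed: activations kept normally

torch.manual_seed(0)
model = SkipNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

# Random toy data standing in for a real DataLoader.
inputs = torch.randn(64, 32)
targets = torch.randint(0, 10, (64,))

losses = []
for step in range(20):
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()  # block1/block2 forwards are recomputed here
    optimizer.step()
    losses.append(loss.item())
```

Nothing else in the training loop changes: the optimizer, loss, and `backward()` call are the same as without checkpointing. With `use_reentrant=False`, the input to a checkpointed function does not need `requires_grad=True` for the block's parameters to receive gradients.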

I used this tutorial from @Priya_Goyal to see how checkpointing should be applied to different kinds of models.
