I am trying to implement gradient checkpointing in my code to work around GPU memory limitations, and I found that PyTorch has an implementation. However, I could not find any usage examples anywhere online. All I see right now is:
>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
This is for sequential models only - I could not find anything for a non-sequential model, even though PyTorch implements that case too (`torch.utils.checkpoint.checkpoint`). I am also not sure what `input_var` and `chunks` are, or how the whole thing fits in with the rest of a training loop. It would be a very useful resource if someone could provide a simple, self-contained example of how this can be used to train a model. Thanks!
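For reference, here is my best guess at a minimal, self-contained sketch based on the docstring. The model, layer sizes, and `chunks=2` are arbitrary choices of mine, so I am not sure this is idiomatic - is this roughly the right usage?

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# An arbitrary toy model (my own choice) -- any nn.Sequential should work.
model = nn.Sequential(
    nn.Linear(100, 50),
    nn.ReLU(),
    nn.Linear(50, 20),
    nn.ReLU(),
    nn.Linear(20, 5),
)

# My understanding: "chunks" is the number of segments the Sequential is
# split into; only segment-boundary activations are stored, and the rest
# are recomputed during the backward pass.
chunks = 2

# My understanding: "input_var" is just the usual input batch; it needs
# requires_grad=True so the checkpointed segments participate in autograd.
input_var = torch.randn(8, 100, requires_grad=True)

out = checkpoint_sequential(model, chunks, input_var)
loss = out.sum()  # placeholder loss for the sketch
loss.backward()
```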