I would like to use gradient checkpointing to train a CNN + RNN in an end-to-end fashion. Checkpointing is not strictly necessary for this problem, but I was hoping to backprop into the CNN through a very long sequence, e.g. 500 images. Without any further measures to reduce the memory footprint, this inevitably breaks training with a CUDA out-of-memory error.
If I understood gradient checkpointing correctly, it should be possible to skip storing the intermediate activations in the forward pass and, during the backward pass, recompute them by rerunning the forward pass for each segment - trading speed for lower memory consumption.
I am still a bit clueless on how to implement this as there seems to be no complete example with an RNN available.
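For reference, here is a minimal sketch of what I have in mind, using `torch.utils.checkpoint.checkpoint` around the per-frame CNN only (the module names, shapes, and hyperparameters are just placeholders I made up for illustration):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CNNRNN(nn.Module):
    """Hypothetical model: a per-frame CNN encoder followed by a GRU over time."""
    def __init__(self, hidden=64):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
        )
        self.rnn = nn.GRU(32, hidden, batch_first=True)

    def forward(self, frames):
        # frames: (batch, seq_len, 3, H, W)
        b, t = frames.shape[:2]
        feats = []
        for i in range(t):
            # checkpoint(): the activations inside self.cnn are NOT kept for
            # this frame; they are recomputed during backward, trading extra
            # forward compute for a much smaller memory footprint.
            f = checkpoint(self.cnn, frames[:, i], use_reentrant=False)
            feats.append(f)
        feats = torch.stack(feats, dim=1)   # (batch, seq_len, 32)
        out, _ = self.rnn(feats)
        return out[:, -1]                   # last hidden state

model = CNNRNN()
x = torch.randn(2, 8, 3, 32, 32, requires_grad=True)
loss = model(x).sum()
loss.backward()
```

I am unsure whether checkpointing each frame individually like this is the right granularity, or whether one should checkpoint chunks of the sequence (and whether the RNN itself can be segmented the same way).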
Another resource I found is from fairscale. Is this a similar kind of activation checkpointing compared to PyTorch's `torch.utils.checkpoint`?
Any help or pointers to additional reading material is highly appreciated.