Long sequential learning with torch.utils.checkpoint.checkpoint

I would like to use activation checkpointing to train a CNN + RNN end to end. Checkpointing is not strictly necessary for this kind of model, but I want to backpropagate into the CNN over a very long sequence, e.g. 500 images. Without any further measures to reduce the memory footprint, this inevitably breaks training with a CUDA out-of-memory error.

I found some helpful resources:
https://pytorch.org/docs/stable/checkpoint.html

If I understood checkpointing correctly, it skips storing the intermediate activations during the forward pass and instead recomputes them in the backward pass by rerunning the forward pass for each checkpointed segment, trading speed for lower memory consumption.

I am still a bit clueless about how to implement this, as there seems to be no complete example with an RNN available. Below is a rough sketch of what I imagine the segmented forward pass could look like.
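This is only my best guess, not something I know to be correct: the function name `run_cnn` and the `segment_len` parameter are placeholders I made up, and I am assuming a recent PyTorch version where `use_reentrant=False` is available on `torch.utils.checkpoint.checkpoint`.

```python
import torch
from torch.utils.checkpoint import checkpoint

def forward_checkpointed(model, frames, segment_len=50):
    """Run the CNN over the sequence in segments, recomputing activations
    in the backward pass instead of storing them (my rough understanding)."""
    b, t = frames.shape[:2]
    feats = []
    for start in range(0, t, segment_len):
        segment = frames[:, start:start + segment_len]  # (B, S, 3, H, W)

        def run_cnn(x):
            s = x.shape[1]
            return model.cnn(x.flatten(0, 1)).view(b, s, -1)

        # use_reentrant=False is the recommended mode in recent PyTorch versions
        feats.append(checkpoint(run_cnn, segment, use_reentrant=False))

    feats = torch.cat(feats, dim=1)  # (B, T, feat_dim)
    out, _ = model.rnn(feats)        # RNN still sees the full sequence
    return out
```

In this sketch only the CNN is checkpointed per segment, while the RNN runs over the concatenated features as usual, but I am not sure whether the RNN itself should also be wrapped in checkpoint segments.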

Another resource I found is from fairscale.

Is this the same kind of activation checkpointing as the PyTorch utility, or something different?

Any help or pointers to additional reading material would be highly appreciated.
