Checkpointing LlamaDecoderLayer or LlamaAttention to help with memory limitations?

I’m working with the Hugging Face Llama 3 models (transformers/src/transformers/models/llama/modeling_llama.py in the huggingface/transformers GitHub repository) and running out of memory because of long sequence lengths. It seems like checkpointing either the LlamaDecoderLayer or the LlamaAttention blocks would help a lot. Ideally FSDP2’s fully_shard would have an option to checkpoint modules as I shard them, but sadly I don’t see one.

I understand torch.utils.checkpoint at a high level (see the torch.utils.checkpoint page in the PyTorch 2.8 documentation), but I don’t see how to apply it to an existing model. Is there a way to do so without hacking the model itself?
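For concreteness, this is roughly what I have in mind. It is a minimal sketch that assumes it is acceptable to patch the `forward` of each matching module instance after the model has been built. `enable_checkpointing`, `ToyDecoderLayer` and `ToyModel` are hypothetical names of my own, not part of transformers or PyTorch. For the real model the class name would be "LlamaDecoderLayer" (or "LlamaAttention"). I pass `use_reentrant=False` because, as far as I understand, only the non-reentrant variant forwards keyword arguments, and the decoder layers are called with several of them.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def _make_checkpointed(forward_fn):
    """Return a forward that runs `forward_fn` under activation checkpointing."""
    def checkpointed_forward(*args, **kwargs):
        # Intermediate activations inside forward_fn are not kept for backward;
        # forward_fn is re-run during backward to recompute them.
        # use_reentrant=False lets keyword arguments pass straight through.
        return checkpoint(forward_fn, *args, use_reentrant=False, **kwargs)
    return checkpointed_forward


def enable_checkpointing(model: nn.Module, class_name: str) -> int:
    """Hypothetical helper: patch the forward of every submodule whose class is
    named `class_name` without touching the model's source code.
    Parameter names and state_dict keys stay unchanged.
    Returns the number of patched modules."""
    patched = 0
    for module in model.modules():
        if type(module).__name__ == class_name:
            # module.forward is the bound class method; the instance attribute
            # set here shadows it when nn.Module.__call__ invokes self.forward.
            module.forward = _make_checkpointed(module.forward)
            patched += 1
    return patched


# --- Toy stand-ins for a decoder stack (not the real Llama classes) ---
class ToyDecoderLayer(nn.Module):
    calls = 0  # counts forward executions, including recomputation

    def __init__(self, dim: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x, scale=1.0):
        ToyDecoderLayer.calls += 1
        return x + scale * self.mlp(x)


class ToyModel(nn.Module):
    def __init__(self, dim: int = 16, n_layers: int = 3):
        super().__init__()
        self.layers = nn.ModuleList(ToyDecoderLayer(dim) for _ in range(n_layers))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x, scale=0.5)  # keyword argument goes through checkpoint
        return x


# Usage: patch an already-constructed model, then train as usual.
torch.manual_seed(0)
model = ToyModel()
n_patched = enable_checkpointing(model, "ToyDecoderLayer")
ToyDecoderLayer.calls = 0
model(torch.randn(4, 16)).sum().backward()
print(n_patched, ToyDecoderLayer.calls)
```

The counter only makes the recomputation visible: with checkpointing, each patched layer’s forward runs a second time during backward in exchange for not keeping its intermediate activations. Patching the instance’s `forward`, rather than wrapping the layer in a new module, is deliberate, so parameter names (and therefore loading pretrained weights) stay the same. What I can’t tell is whether this recomputation plays well with FSDP2’s unshard/reshard hooks when fully_shard is applied to the same layers.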

Even some pseudo-code to point me in the right direction would be welcome.