Question about activation checkpointing with FSDP

I found that PyTorch’s FSDP has its own wrapping function, apply_activation_checkpointing_wrapper, for activation checkpointing.

I want to know the difference between apply_activation_checkpointing_wrapper and gradient_checkpointing_enable.
When applying activation checkpointing with PyTorch’s FSDP, should I use that function instead of the gradient_checkpointing_enable provided by Hugging Face models such as GPT-2?

I think gradient_checkpointing_enable() is Hugging Face’s own built-in method; it works because Hugging Face models have manual activation checkpointing calls in their source code that can be enabled or disabled.
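
For example (a minimal sketch, assuming a GPT-2 checkpoint from the Hub):

```python
# Minimal sketch: enabling Hugging Face's built-in checkpointing on GPT-2.
# The checkpoint locations are already defined inside the model's forward code;
# this call just switches them on.
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
model.gradient_checkpointing_enable()  # activations are recomputed in backward instead of stored
```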

apply_activation_checkpointing_wrapper() works for general models (not just Hugging Face ones) because the user passes the criterion that decides which submodules to checkpoint, as in the sketch below. If you are using a Hugging Face model, you can try gradient_checkpointing_enable() first, since those checkpoint locations have been hand-picked. That said, I am not familiar with how well it interacts with FSDP.
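
Here is a sketch of the PyTorch-native route, not a definitive recipe: it assumes GPT-2, that every GPT2Block should be checkpointed, and that a distributed process group is already initialized (e.g. launched via torchrun). Depending on your PyTorch version, the function may be exposed as apply_activation_checkpointing instead.

```python
import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing_wrapper,
    checkpoint_wrapper,
    CheckpointImpl,
)
from transformers import GPT2LMHeadModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Block

model = GPT2LMHeadModel.from_pretrained("gpt2")
model = FSDP(model)  # shard parameters first, then apply checkpointing to the wrapped model

# Wrap every GPT2Block with a non-reentrant checkpoint wrapper.
non_reentrant_wrapper = functools.partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
apply_activation_checkpointing_wrapper(
    model,
    checkpoint_wrapper_fn=non_reentrant_wrapper,
    check_fn=lambda submodule: isinstance(submodule, GPT2Block),
)
```

The check_fn here plays the role of the hand-picked checkpoints in the Hugging Face path: you decide which submodules get their activations recomputed.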
