Quick update on this…
Activation checkpointing gave the largest GPU memory reduction (~40%) for this application, where the inputs are large (medical image segmentation). FSDP wasn’t beneficial because the model has few parameters/gradients, so there is little for it to shard.
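For anyone landing here later, a minimal sketch of what activation checkpointing looks like in PyTorch. This is not the actual model from this thread: `ConvBlock`, `TinySegNet`, the channel widths, and the input size are hypothetical and only there for illustration. The one real API call being shown is `torch.utils.checkpoint.checkpoint`.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ConvBlock(nn.Module):
    """Hypothetical 3D conv block standing in for one stage of a segmentation net."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1),
            # InstanceNorm keeps no running statistics by default, so running
            # the block a second time during backward has no side effects.
            nn.InstanceNorm3d(out_ch),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.body(x)


class TinySegNet(nn.Module):
    """Hypothetical toy network; sizes are assumptions, not from the thread."""

    def __init__(self, in_ch=1, width=8, num_classes=2, use_checkpoint=True):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.blocks = nn.ModuleList(
            [ConvBlock(in_ch, width), ConvBlock(width, width), ConvBlock(width, width)]
        )
        self.head = nn.Conv3d(width, num_classes, kernel_size=1)

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpoint and self.training:
                # Activations inside the block are not kept after the forward
                # pass; they are recomputed when backward reaches this block.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return self.head(x)


def train_step(model, volume, target):
    """One forward/backward pass; returns the scalar loss value."""
    model.train()
    loss = nn.functional.cross_entropy(model(volume), target)
    loss.backward()
    return loss.item()


if __name__ == "__main__":
    model = TinySegNet(use_checkpoint=True)
    volume = torch.randn(1, 1, 8, 16, 16)          # (N, C, D, H, W), tiny on purpose
    target = torch.randint(0, 2, (1, 8, 16, 16))   # per-voxel class labels
    print(train_step(model, volume, target))
```

The trade-off is extra compute: each checkpointed block runs its forward pass twice. To check the saving on a real model, one option is to compare `torch.cuda.max_memory_allocated()` for a training step with and without checkpointing.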
Thanks for the help @agu!