I found that the current activation checkpointing wrapper (`torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py` on main) can make a model fail at the TorchScript stage, specifically in `torch.jit.save`:
```
RuntimeError:
Could not export Python function call '_NoopSaveInputs'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
/usr/local/lib/python3.8/site-packages/torch/autograd/function.py(539): apply
/usr/local/lib/python3.8/site-packages/torch/utils/checkpoint.py(1203): _checkpoint_without_reentrant_generator
/usr/local/lib/python3.8/site-packages/torch/utils/checkpoint.py(457): checkpoint
/usr/local/lib/python3.8/site-packages/torch/_dynamo/external_utils.py(17): inner
/usr/local/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py(328): _fn
/usr/local/lib/python3.8/site-packages/torch/_compile.py(24): inner
```
But if I just use the raw `torch.utils.checkpoint` API directly, the export works fine.
Does the checkpointing wrapper essentially not support TorchScript?
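For reference, this is roughly how I call the raw API instead of the wrapper (a minimal sketch; the toy `Block`/`Model` modules and their sizes are made up for illustration, not taken from my actual model):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    # Toy block standing in for a real layer whose activations
    # we want to recompute instead of storing.
    def __init__(self, dim: int = 16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class Model(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.block = Block(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Call torch.utils.checkpoint directly rather than wrapping
        # self.block with checkpoint_wrapper; the block's activations
        # are recomputed during backward.
        return checkpoint(self.block, x, use_reentrant=False)


model = Model()
x = torch.randn(4, 16, requires_grad=True)
out = model(x)
out.sum().backward()  # triggers recomputation of the block's forward
```

Gradients come out identical to a non-checkpointed run; only the activation memory behavior differs.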