I want to implement a fault tolerance mechanism in distributed training. I wonder if there is a way to fully clear all the activations created in a forward pass: some data can cause an OOM in one process, and if I could clear the activations manually, the training process could recover and be ready to process other data.
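For concreteness, here is a minimal sketch of the recovery loop I have in mind; the tiny model and random batches below just stand in for my real distributed setup:

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(16, 16).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

def train_step(batch, target):
    optimizer.zero_grad(set_to_none=True)
    output = model(batch)            # activations stay alive via the autograd graph
    loss = criterion(output, target)
    loss.backward()                  # frees the graph and its activations
    optimizer.step()

for _ in range(10):
    batch = torch.randn(4, 16, device=device)
    target = torch.randn(4, 16, device=device)
    oom = False
    try:
        train_step(batch, target)
    except torch.cuda.OutOfMemoryError:
        # Set a flag and recover outside the handler, so the exception
        # traceback no longer pins train_step's locals (output, loss).
        oom = True
    if oom:
        torch.cuda.empty_cache()     # hand cached blocks back to the allocator
        continue                     # skip the offending batch and carry on
```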
I believe these should be freed by default, but if you want to, you could try using a forward hook (which fires right after forward has run) to manually delete the activations, and see if that helps.
The documentation on forward hooks is here: Module — PyTorch 2.4 documentation
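For illustration, a minimal sketch of the hook idea. Note that `saved_activations` here is a hypothetical cache you would maintain yourself; tensors saved by autograd are only released when the graph itself is freed (e.g., by `backward()` or by dropping the output):

```python
import torch
import torch.nn as nn

saved_activations = {}

def save_hook(module, inputs, output):
    # Runs right after module.forward; stash the output for later inspection.
    saved_activations[id(module)] = output.detach()

model = nn.Linear(16, 16)
handle = model.register_forward_hook(save_hook)

x = torch.randn(4, 16)
y = model(x)

# On an OOM (or whenever you want to recover), drop the extra references
# you hold and return cached blocks to the allocator.
saved_activations.clear()
handle.remove()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
```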
I think you are right. The memory fails to be released automatically only when I wrap the model with DeepSpeed; with plain PyTorch everything is fine.