How to Control Staged Backward Execution in PyTorch for a Llama2 Model?

Hello everyone,

I am training a Llama2 model with PyTorch, and calling loss.backward() on the full model takes too long to finish within the time slot I have available. I would therefore like to split the Llama2 model into multiple submodules along its decoder layers, with each submodule containing several consecutive decoder layers, roughly as sketched below.
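To make the split concrete, here is roughly how I am grouping the layers into stages. This is only a sketch: the `nn.Linear` layers stand in for the real Llama2 decoder layers (in the Hugging Face implementation they would come from `model.model.layers`), and `layers_per_stage` / `stages` are just names I made up for this example. With the real decoder layers I would use a small wrapper instead of `nn.Sequential`, since they also take attention masks and position ids, but the grouping idea is the same.

```python
import torch.nn as nn

# Toy stand-ins for the Llama2 decoder layers; in my real code these would be
# the actual decoder layers (e.g. model.model.layers in the HF implementation).
decoder_layers = nn.ModuleList(nn.Linear(64, 64) for _ in range(8))

# Group the decoder layers into submodules ("stages"), e.g. 4 layers per stage.
layers_per_stage = 4
stages = nn.ModuleList(
    nn.Sequential(*decoder_layers[i:i + layers_per_stage])
    for i in range(0, len(decoder_layers), layers_per_stage)
)
```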

However, I am not sure how to run backpropagation in stages. As far as I can tell, the computation graph is built automatically during the forward pass, and calling loss.backward() then runs the entire backward pass in one go. Is there a good way to execute the backward pass one submodule (stage) at a time, so that other work can happen between stages? A rough sketch of what I have in mind is below.
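For reference, here is the kind of per-stage control I have in mind: detach the activations at each stage boundary during the forward pass so that every submodule keeps its own small graph, then call torch.autograd.backward() on the boundary tensors from the last stage back to the first. This is only a minimal sketch with toy layers and names I made up (`stages`, `boundary_inputs`, `boundary_outputs`); I am not sure it is the right or most efficient approach, which is why I am asking.

```python
import torch
import torch.nn as nn

# Toy stages standing in for groups of Llama2 decoder layers.
stages = nn.ModuleList(
    nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64)) for _ in range(3)
)

x = torch.randn(4, 64)
target = torch.randn(4, 64)

# Forward pass: detach at every stage boundary so each stage owns its own
# small graph instead of one graph spanning the whole model.
boundary_inputs, boundary_outputs = [], []
h = x
for stage in stages:
    h = h.detach().requires_grad_(True)   # cut the graph at the boundary
    boundary_inputs.append(h)
    h = stage(h)
    boundary_outputs.append(h)

loss = nn.functional.mse_loss(boundary_outputs[-1], target)

# Backward pass, one stage at a time, from the last stage to the first.
grad = None
for inp, out in zip(reversed(boundary_inputs), reversed(boundary_outputs)):
    if grad is None:
        # Last stage: start the backward from the loss.
        torch.autograd.backward(loss)
    else:
        # Earlier stages: continue from the gradient w.r.t. the boundary tensor.
        torch.autograd.backward(out, grad_tensors=grad)
    grad = inp.grad  # gradient flowing into the previous stage
    # ... other work (e.g. a per-stage optimizer step, or yielding to a
    #     scheduler) could run here between the per-stage backward calls
```

Parameter gradients still accumulate into each stage's weights as usual; only the activation graph is cut at the boundaries. Does this look like a reasonable way to do it, or is there a better-supported mechanism?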

Thank you!