Hello everyone,
I am training a Llama2 model with PyTorch, and I noticed that a single loss.backward() call takes too long to finish within my allotted time slot. I therefore want to split the Llama2 model into multiple submodules along its decoder layers, with each submodule containing several consecutive decoder layers.
However, I am unsure how to control the backward pass in stages. As far as I can tell, the computation graph is built automatically during the forward pass, and calling loss.backward() then runs the entire backward pass in one go. Is there a good way to execute backward propagation stage by stage?
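For concreteness, here is a minimal sketch of the kind of staged backward I have in mind, using a toy two-stage split (stage1 and stage2 are hypothetical stand-ins for groups of decoder layers, not my actual model). The idea is to detach the activation at the stage boundary, run backward on the later stage first, and then feed the saved boundary gradient into the earlier stage. Is this the right direction, or is there a better-supported way?

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for two groups of Llama2 decoder layers.
stage1 = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
stage2 = nn.Sequential(nn.Linear(16, 1))

x = torch.randn(4, 16)

# Forward pass, cutting the graph at the stage boundary.
h1 = stage1(x)
h1_detached = h1.detach().requires_grad_(True)  # detached boundary activation
out = stage2(h1_detached)
loss = out.sum()

# Stage 2 backward: populates stage2's parameter grads and h1_detached.grad,
# but does not touch stage1 at all.
loss.backward()

# Stage 1 backward, which could run later (e.g., in the next time slot):
# feed the saved boundary gradient into the earlier stage's graph.
h1.backward(h1_detached.grad)
```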
Thank you!