How can I run a “partial” backward pass on one subgraph, pause, and then continue upstream into an earlier subgraph, all within the same overall backward pass?
My current approach, which works to some extent, is to freeze everything except one subgraph, call backward(), refreeze that subgraph, unfreeze the next one, and call backward() again. In effect, repeated backward() calls compose the overall backward pass piece by piece.
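A minimal sketch of this repeated-call pattern, assuming a toy model with two subgraphs. The names `stage1` and `stage2` are hypothetical, and the `inputs=` argument of `Tensor.backward` is used here as a stand-in for freezing and unfreezing parameters (it restricts gradient accumulation to the listed tensors), so this may differ in detail from the actual setup:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for two subgraphs of a larger model.
stage1 = nn.Linear(4, 3)
stage2 = nn.Linear(3, 1)

x = torch.randn(8, 4)
loss = stage2(torch.relu(stage1(x))).sum()

# Partial pass 1: only stage2's parameters receive gradients.
# The graph has to be retained because the next call walks it again.
loss.backward(inputs=list(stage2.parameters()), retain_graph=True)

# ... any work between the partial passes would go here ...

# Partial pass 2: only stage1's parameters receive gradients.
# This traverses stage2's part of the graph a second time to reach stage1.
loss.backward(inputs=list(stage1.parameters()))
```

Every call except the last has to keep the saved activations of the whole graph alive, which is where the memory cost comes from.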
However, this requires retain_graph=True. Because of memory constraints, I’d prefer to avoid that; instead, I’d like all of these steps to be part of a single backward pass. Is this possible? I suspected it might require rewriting the execution engine, but can it be done with the higher-level APIs?
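One possible direction with the higher-level APIs, offered only as a sketch: cut the graph at the boundary between subgraphs with `detach()`, so that each subgraph owns a separate graph that is freed by its own `backward()` call. This assumes the subgraphs meet at a single boundary tensor; `stage1`, `stage2`, `h`, and `h_cut` are hypothetical names, and it is not strictly one backward pass but a chain of passes that never revisit a freed graph:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for two subgraphs of a larger model.
stage1 = nn.Linear(4, 3)
stage2 = nn.Linear(3, 1)
x = torch.randn(8, 4)

# Forward through the earlier subgraph, then cut the graph at the boundary.
h = torch.relu(stage1(x))
h_cut = h.detach().requires_grad_(True)  # new leaf tensor sharing h's data

# Forward through the later subgraph, starting from the cut point.
loss = stage2(h_cut).sum()

# Partial backward: stops at h_cut and frees only stage2's graph.
loss.backward()
boundary_grad = h_cut.grad  # d(loss)/d(h), saved for the next step

# ... pause here; stage1's graph is still intact and untouched ...

# Continue upstream: feed the boundary gradient into the earlier subgraph.
h.backward(boundary_grad)
```

Neither call needs `retain_graph=True`, since each graph is traversed exactly once. Whether this fits depends on how the subgraphs are connected; with several boundary tensors, each one would need its own cut and its own saved gradient.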