Hi all! This is my first post here and I’m looking forward to learning more about PyTorch internals.
I have a large model that I’m running through torch.compile and aot_autograd. The model has various graph breaks, mostly due to calls to torch.nonzero. For small models it’s fairly clear how to “stitch” these graphs together and run the entire compiled flow for both the forward and backward passes. For large models, though, the process is much more error-prone: I sometimes have hundreds of intermediate tensors to marshal between each graph and its corresponding backward graph, and it’s not always clear which tensor each “tangent” corresponds to, because I can have multiple tangents with the same shape and dtype.
Is there a way I can use PyTorch itself to untangle this mess for me automatically? In a perfect world I would have a Python function that, given input arguments to the model, shows me 1) how to invoke each forward graph, 2) what operations (if any) need to occur between graphs, and 3) how to invoke each backward graph given the forward graph’s results. This is essentially what PyTorch does when someone invokes the compiled model (and autograd, I guess), but trying to replicate that process through the debugger seems hopeless given all the stack layers one would have to wade through.
Thanks in advance for any tips here!