Hi,
`aot_module` can be used to get a handle on the forward and backward graphs traced by autograd. For example:
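```python
import torch
from functorch.compile import aot_module

def print_aten_ops(gm, example_inputs):
    def trace_printer(gm, _):
        gm.print_readable()  # print the traced ATen-level FX graph
        return gm
    # Use the same printing "compiler" for both the forward and backward graphs
    return aot_module(gm, fw_compiler=trace_printer, bw_compiler=trace_printer)

model_compiled = torch.compile(model, fullgraph=True, backend=print_aten_ops)
```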
I’m trying to trace how nodes are generated in the backward pass, so that I can attribute each node in the backward graph to the node(s) in the forward pass that caused it to be created. I see there are facilities for adding hooks to modules/tensors, but those don’t necessarily give me what I need.
Ideally, each node in the returned FX graph would be “annotated” with an identifier of the forward-pass node it came from.
Any pointers will be very helpful!