Tracing provenance of nodes in the backward pass


`aot_module` can be used to get a handle on the forward and backward graphs that AOTAutograd traces. For example:

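```python
import torch
from functorch.compile import aot_module

def print_aten_ops(gm, example_inputs):
    def trace_printer(gm, _):
        # Print the ATen-level graph traced by AOTAutograd, then run it unchanged
        gm.print_readable()
        return gm
    return aot_module(gm, fw_compiler=trace_printer, bw_compiler=trace_printer)

# `model` is any torch.nn.Module
model_compiled = torch.compile(model, fullgraph=True, backend=print_aten_ops)
```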

I'm trying to trace how nodes are generated in the backward pass, so that I can attribute each backward node to the forward-pass node(s) whose computation caused it to be created. I see there are facilities for registering hooks on modules and tensors, but they don't give me this node-level mapping.

Ideally, each node in the returned backward FX graph would be "annotated" with an identifier for the forward-pass node that produced it.
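One possible approach, sketched here under assumptions, is to match backward nodes to forward nodes through a shared autograd sequence number. The sketch assumes that the traced nodes carry such a number in `node.meta["seq_nr"]`. Whether it is recorded, and under which key, depends on the PyTorch version, so print `node.meta` inside the compilers to check what your build provides. `attribute_backward_nodes`, `FakeNode` and the `"fwd_provenance"` meta key are hypothetical names introduced only for illustration.

```python
from collections import defaultdict
from dataclasses import dataclass, field


@dataclass
class FakeNode:
    # Stand-in for torch.fx.Node so this sketch runs without torch;
    # real fx Nodes expose the same .op, .name and .meta attributes.
    op: str
    name: str
    meta: dict = field(default_factory=dict)


def attribute_backward_nodes(fw_nodes, bw_nodes, key="seq_nr"):
    """Map each backward node name to the forward node names sharing node.meta[key].

    ASSUMPTION: both graphs record an autograd sequence number under
    node.meta[key] (default "seq_nr"); this is version-dependent.
    "fwd_provenance" is a hypothetical meta key used for the annotation.
    """
    # Index forward compute nodes by their sequence number.
    fw_by_key = defaultdict(list)
    for node in fw_nodes:
        if node.op == "call_function" and key in node.meta:
            fw_by_key[node.meta[key]].append(node.name)

    # Annotate each backward compute node with the forward nodes it came from.
    provenance = {}
    for node in bw_nodes:
        if node.op == "call_function" and key in node.meta:
            sources = list(fw_by_key.get(node.meta[key], []))
            node.meta["fwd_provenance"] = sources
            provenance[node.name] = sources
    return provenance
```

With real graphs, one option is to stash the forward `GraphModule` in `fw_compiler` and then call `attribute_backward_nodes(fw_gm.graph.nodes, bw_gm.graph.nodes)` inside `bw_compiler`. A sequence number only identifies the autograd operation, so several forward nodes may map to the same backward node (and vice versa). That is why each entry holds a list.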

Any pointers will be very helpful!