Hi,
I want to go from a torch::jit::Operator, or a torch::jit::Node for an op in a particular graph, to the torch::jit::Graph that represents its derivative.
I think this information is encoded in some form in derivatives.yaml, but I’m not sure how to get at it through the available APIs.
Either the Python Node/Graph bindings or the C++ API would work for my purposes.
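For illustration only, here is a toy Python sketch of the transformation being asked about: walk a forward graph in reverse and, for each node, look up a per-op gradient rule that emits backward nodes. Everything here is hypothetical: the mini-IR (`Node`, `Builder`), the rule table `RULES` (standing in for the per-input formulas in derivatives.yaml), and the interpreter `run`. It is not the torch::jit API and does not read derivatives.yaml.

```python
import math
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple

# Hypothetical toy IR node (NOT torch::jit::Node): `output = op(*inputs)`.
@dataclass
class Node:
    op: str
    inputs: Tuple[str, ...]
    output: str

class Builder:
    """Collects backward nodes and hands out fresh value names."""
    def __init__(self) -> None:
        self.nodes: List[Node] = []
        self.count = 0

    def emit(self, op: str, *inputs: str) -> str:
        self.count += 1
        out = f"%g{self.count}"
        self.nodes.append(Node(op, tuple(inputs), out))
        return out

# Hypothetical rule table in the spirit of derivatives.yaml: given the
# incoming gradient name `g`, return one gradient value name per input.
RULES: Dict[str, Callable[[Builder, Node, str], List[str]]] = {
    "add": lambda b, n, g: [g, g],
    "mul": lambda b, n, g: [b.emit("mul", g, n.inputs[1]),
                            b.emit("mul", g, n.inputs[0])],
    "sin": lambda b, n, g: [b.emit("mul", g, b.emit("cos", n.inputs[0]))],
}

def build_backward(forward: List[Node], output: str):
    """Reverse-mode pass: returns (backward nodes, input -> grad value name)."""
    b = Builder()
    grads = {output: "%grad_out"}
    for node in reversed(forward):
        g = grads.get(node.output)
        if g is None:
            continue  # this value does not influence `output`
        for inp, gi in zip(node.inputs, RULES[node.op](b, node, g)):
            # Accumulate when a value feeds more than one consumer.
            grads[inp] = b.emit("add", grads[inp], gi) if inp in grads else gi
    return b.nodes, grads

# Toy interpreter so the generated backward graph can be checked numerically.
IMPLS = {"add": lambda a, c: a + c, "mul": lambda a, c: a * c,
         "sin": math.sin, "cos": math.cos}

def run(nodes: List[Node], env: Dict[str, float]) -> Dict[str, float]:
    env = dict(env)
    for n in nodes:
        env[n.output] = IMPLS[n.op](*(env[i] for i in n.inputs))
    return env

# f(x, y) = x * y + sin(x)
forward = [
    Node("mul", ("%x", "%y"), "%t0"),
    Node("sin", ("%x",), "%t1"),
    Node("add", ("%t0", "%t1"), "%out"),
]
bwd, grads = build_backward(forward, "%out")
for n in bwd:
    print(f"{n.output} = {n.op}({', '.join(n.inputs)})")

env = run(forward, {"%x": 0.5, "%y": 3.0})
env = run(bwd, {**env, "%grad_out": 1.0})
print(env[grads["%x"]], env[grads["%y"]])  # df/dx = y + cos(x), df/dy = x
```

The printed backward graph is itself a graph of ordinary ops, which is the kind of artifact the question is after for real ATen ops in a torch::jit::Graph.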
So is there no way to go from a ScriptModule or jit::Graph to the equivalent ScriptModule or jit::Graph for the gradient (backward pass)?
Could we somehow leverage torch.jit.trace() to produce a graph from a backward pass?
Summary of the use case: torch.onnx.export() currently has an ATen fallback mode (operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK), which means that if an ATen op has no ONNX equivalent, it is exported as-is as an ATen op. Some consumers of the exported ONNX model need to compute the backward (gradient) graph. They can currently do this for a pure ONNX graph, but I’d like to extend this to ATen ops as well. Hence the request.