I want to go from a torch::jit::Operator or a torch::jit::Node for an op in a particular graph to the torch::jit::Graph that represents its derivative.
I think this information is encoded in some form in derivatives.yaml, but I’m not sure how to get it from the APIs available.
Either the Python Node / Graph APIs or C++ would work for my purposes.
Thanks for any tips!
There is currently no API for this, though I’m curious: what is your use case?
Thanks for taking the time to reply @soulitzer.
So there’s no way to go from a jit::Graph to the equivalent jit::Graph for the gradient / backward pass?
Could we somehow leverage torch.jit.trace() to produce a graph from a backward pass?
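As a possible alternative to torch.jit.trace(), not something suggested in this thread: newer PyTorch releases (this sketch assumes PyTorch 2.x, where torch.func and torch.fx.experimental.proxy_tensor.make_fx are available) can record the ATen ops executed by a gradient computation. The result is a torch.fx graph, not a torch::jit::Graph, and it also contains the forward ops that the backward formulas depend on. The function f is just an illustrative example:

```python
import torch
from torch.func import grad
from torch.fx.experimental.proxy_tensor import make_fx


def f(x: torch.Tensor) -> torch.Tensor:
    # torch.func.grad requires a scalar-valued function.
    return torch.sin(x).sum()


# grad(f) computes df/dx; make_fx records every ATen op it runs into a GraphModule.
backward_gm = make_fx(grad(f))(torch.randn(3))
print(backward_gm.graph)

# Collect the ATen op overloads called in the traced graph.
targets = [n.target for n in backward_gm.graph.nodes if n.op == "call_function"]
```

For sin, the backward formula multiplies the incoming gradient by cos(x), so an aten.cos call should appear among the recorded ops.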
Summary of the use case: torch.onnx.export() currently has an ATEN_FALLBACK mode, which means that if an ATen op has no ONNX equivalent, it is exported as-is. Some consumers of the exported ONNX model need to compute the backward / gradient graph. They can currently do this for a pure ONNX graph, but I’d like to extend it to ATen ops as well. Hence the request.