Backward pass through ONNX object

I would like to make an ONNX model differentiable. As I understand it, exporting to ONNX does not export the autograd graph. Is there any way to reconstruct it after loading?

I am aware of torch-ort, but as far as I can tell it only works with nn.Module objects, i.e. the original Python PyTorch models (see examples here, here and here). Is that right?

Is there any way to load an ONNX-exported model and have PyTorch or ONNX Runtime reconstruct the backward graph?

Alternatively, can I get the ONNX exporter to include the backward graph of a PyTorch nn.Module model, so that I can run it with ONNX Runtime?

Background: I want to work with physics-based models, where I can easily write the forward "energy" function and then use its gradient (the "forces") in my simulations. At present we need either numerical differentiation or analytic expressions derived beforehand.
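To make the trade-off concrete: the forces are the negative gradient of the energy, F = -∇E. With only a forward pass available (as with an exported model), the fallback is central finite differences, which costs two forward evaluations per coordinate. The sketch below is plain Python and assumes a hypothetical harmonic energy E(x) = ½ k Σ x_i² as a stand-in for the model; the function names are illustrative, not part of any library.

```python
from typing import Callable, List, Sequence


def energy(x: Sequence[float], k: float = 2.0) -> float:
    """Hypothetical harmonic energy E(x) = 0.5 * k * sum(x_i^2).

    Stands in for a black-box forward pass that only returns E.
    """
    return 0.5 * k * sum(xi * xi for xi in x)


def analytic_forces(x: Sequence[float], k: float = 2.0) -> List[float]:
    """Exact forces F = -dE/dx = -k * x for the harmonic energy above."""
    return [-k * xi for xi in x]


def finite_difference_forces(
    f: Callable[[List[float]], float], x: Sequence[float], h: float = 1e-5
) -> List[float]:
    """Approximate F = -grad E with central differences.

    Needs two calls to f per coordinate, so cost grows with len(x).
    """
    forces = []
    for i in range(len(x)):
        x_plus = list(x)
        x_minus = list(x)
        x_plus[i] += h
        x_minus[i] -= h
        forces.append(-(f(x_plus) - f(x_minus)) / (2.0 * h))
    return forces


if __name__ == "__main__":
    x = [0.5, -1.0, 2.0]
    print("finite differences:", finite_difference_forces(energy, x))
    print("analytic:          ", analytic_forces(x))
```

A reverse-mode (autograd) backward pass would return the full gradient at a cost of a small constant multiple of one forward pass, independent of the number of coordinates, which is why reconstructing the backward graph is attractive for this use case.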
