Is there any way to prepend to the computational graph?

I’m writing a second-order optimizer that optimizes model parameters in a low-dimensional reparameterized space. Ideally, I’d like to:

  1. Take a step in the low-dimensional space
  2. Push the low-dimensional parameters forward to the high-dimensional model space
  3. Compute the loss and backpropagate through both the model and the pushforward function, all the way back to the low-dimensional space (rough sketch below)
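
Here is a rough sketch of the flow I'm after. The linear map `A`, the dimensions, and the variable names are all placeholders; my real pushforward is more involved:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
n_params = sum(p.numel() for p in model.parameters())

# Low-dimensional parameters (d << n_params): the actual optimization
# variables I want gradients for.
d = 3
z = torch.randn(d, requires_grad=True)

# Placeholder pushforward: a fixed linear map from the low-dimensional
# space to the flattened parameter space.
A = torch.randn(n_params, d)

theta = A @ z  # high-dimensional parameters, differentiable w.r.t. z

# ??? <- this is the part I can't figure out: make `model` run its
# forward pass on `theta`, so the graph is z -> theta -> loss.
x = torch.randn(4, 10)
loss = model(x).sum()
loss.backward()  # I want z.grad populated through the pushforward
```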

Unfortunately, I'm finding it difficult to prepend the pushforward function to the model's computational graph, since in-place updates of leaf variables break autograd (minimal failing example below).
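
For concreteness, this is the kind of naive in-place update I've tried (again with placeholder names):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)

# Stand-in for the output of the pushforward: a tensor that carries a
# graph back to the low-dimensional parameters z.
z = torch.randn(3, requires_grad=True)
theta = z.sum() * torch.ones_like(model.weight)

model.weight.copy_(theta)
# RuntimeError: a leaf Variable that requires grad is being used in an
# in-place operation.

# Wrapping the copy in `with torch.no_grad():` makes it run, but then
# the graph from z to the parameters is severed, so backward() from
# the loss never reaches the pushforward or z.
```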

Is there a canonical/recommended way to approach this? Or should I just abandon torch.nn entirely and write everything as a pure function?

Thanks for your help!