I’m writing a second-order optimizer that optimizes model parameters in a low-dimensional reparameterized space. Ideally, I’d like to:
- Take a step in the low-dimensional space
- Push the low-dimensional parameters forward to the high-dimensional model space
- Compute the loss and backpropagate through both the model and the pushforward function, all the way to the low-dimensional space.
Unfortunately, I’m finding it difficult to prepend the pushforward function to the model’s computational graph, since in-place updates of leaf variables break autograd.
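For concreteness, here is a minimal sketch of the three steps above. It assumes `torch.func.functional_call` (PyTorch 2.0 or later) is available, and it uses a hypothetical linear pushforward `theta = theta_0 + P @ z` with a fixed random matrix `P`. The `nn.Linear` model, `P`, the damping value, and all sizes are placeholders for illustration, not my actual setup.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)

model = nn.Linear(4, 2)  # stand-in for the real model

# Frozen copies of the initial parameters; these act as the base point theta_0.
base = {name: p.detach().clone() for name, p in model.named_parameters()}
n_high = sum(p.numel() for p in base.values())

# Hypothetical pushforward: theta = theta_0 + P @ z, with a fixed random P.
n_low = 3
P = torch.randn(n_high, n_low) / n_high ** 0.5
z = torch.zeros(n_low, requires_grad=True)  # the only leaf being optimized


def pushforward(z):
    """Map low-dimensional z to a dict of full-size parameter tensors."""
    flat = P @ z
    params, offset = {}, 0
    for name, p0 in base.items():
        n = p0.numel()
        params[name] = p0 + flat[offset:offset + n].view_as(p0)
        offset += n
    return params


def loss_fn(z, x, y):
    # functional_call runs the module with the supplied tensors in place of
    # its registered parameters, so no leaf variable is modified in place.
    out = functional_call(model, pushforward(z), (x,))
    return nn.functional.mse_loss(out, y)


x, y = torch.randn(8, 4), torch.randn(8, 2)

# Gradient and Hessian with respect to z only.
loss = loss_fn(z, x, y)
(grad,) = torch.autograd.grad(loss, z)
hess = torch.autograd.functional.hessian(lambda v: loss_fn(v, x, y), z)

# One damped Newton step in the low-dimensional space (damping is arbitrary).
step = torch.linalg.solve(hess + 1e-4 * torch.eye(n_low), grad)
z_new = (z - step).detach()
new_loss = loss_fn(z_new, x, y)
```

In this sketch the module's own parameters never receive gradients; they only define the architecture, and every differentiable quantity flows through `z`.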
Is there a canonical/recommended way to approach this? Or should I just abandon torch.nn entirely and write everything as a pure function?
Thanks for your help!