I am currently working on a reimplementation of a (model-agnostic) meta-learning paper (Lazy MAML, ICCV 2021), which approximates the gradient through the inner-loop process as a single step. The inner learner can then perform many steps (say k = 10…20), and to compute the meta-gradient we approximate the expensive backward computation as though a single inner-loop step (or leap) had been taken.
Converting this logic into actual code gives me some headaches, as I fail to see how it can be implemented properly. My current idea boils down to the following four steps; it is the last one that gives me problems.
First, let us define that φ denotes the inner-loop model and θ the outer-loop model (the meta-learner).

Step 1: Do the inner-loop optimization. For this I plan to use autograd without retain_graph enabled; in my baseline implementations of MAML/ANIL/BOIL with learn2learn it needs to be enabled to allow the outer-loop backprop.
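Step 1 might look roughly like the following: clone θ into φ and run k plain SGD steps on the support set, keeping no graph across steps (this is a toy sketch with assumed names such as `inner_lr`, `x_support`, and a small `nn.Linear` standing in for the real model).

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

theta = nn.Linear(4, 2)          # toy stand-in for the outer-loop model θ
phi = copy.deepcopy(theta)       # inner-loop model φ, initialised at θ
inner_lr, k = 0.1, 10            # assumed inner learning rate and step count

x_support = torch.randn(16, 4)
y_support = torch.randint(0, 2, (16,))
loss_fn = nn.CrossEntropyLoss()

params = list(phi.parameters())
for _ in range(k):
    loss = loss_fn(phi(x_support), y_support)
    # no create_graph / retain_graph: we never differentiate through this loop
    grads = torch.autograd.grad(loss, params)
    with torch.no_grad():
        for p, g in zip(params, grads):
            p -= inner_lr * g    # plain, non-differentiable SGD step
```

Because nothing is retained, this inner loop is cheap in memory regardless of k, which is exactly what distinguishes it from the fully differentiable MAML inner loop.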

Step 2: Calculate the single step from θ to φ, i.e. grad = φ − θ, by looping over their parameters and computing the single gradient that approximates the whole inner loop. I will call this the 'derived' gradient.
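As a sketch of step 2, the derived gradient is just a per-parameter difference. I follow the convention grad = φ − θ stated above; depending on how the outer update applies it, one may prefer the opposite sign or a division by the inner learning rate (the perturbation of φ below is only a stand-in for the real k-step inner loop).

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

theta = nn.Linear(4, 2)          # toy θ
phi = copy.deepcopy(theta)
with torch.no_grad():
    for p in phi.parameters():
        p -= 0.05 * torch.randn_like(p)   # stand-in for k inner-loop steps

# one 'derived' gradient tensor per parameter; detached, since the inner
# loop built no graph anyway (convention from the text: grad = φ − θ)
derived = [(p_phi - p_th).detach()
           for p_th, p_phi in zip(theta.parameters(), phi.parameters())]
```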

Step 3: Using φ, calculate the loss on the query set, which we need for the meta-gradient.
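Step 3 is a plain forward pass. The catch, which motivates step 4, is that a query loss computed on the detached φ has no autograd path back to θ, so on its own it cannot yield a meta-gradient (toy stand-ins below):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

phi = nn.Linear(4, 2)                      # stand-in for the adapted model φ
x_query = torch.randn(8, 4)
y_query = torch.randint(0, 2, (8,))

# ordinary forward pass on the query set with the adapted weights
query_loss = nn.CrossEntropyLoss()(phi(x_query), y_query)
```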

Step 4: Perform the backpropagation 'through the inner loop' using my 'derived' gradient update.
For step 4 I consider the following, but I feel like I am abusing some of PyTorch's functionality. First, I could perform a normal MAML update from θ to get ψ and do a backward pass with autograd.grad
with retain_graph enabled (using learn2learn's update_module functionality to make the update differentiable). Then, using backward hooks, I can set the gradients to my 'derived' gradients instead of the modules' actual gradients, such that ψ == φ, but with the computational graph needed.
Now I have a (fudged) computational graph for autograd and can compute the second-order derivative with the loss that I get from step 3. However, I would need to recompute this loss (after first taking a step with my 'derived' gradient), as ψ and the loss are otherwise unrelated.
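One hook-free way to get the "ψ == φ, but with the graph" behaviour described above would be a straight-through-style substitution: build ψ from θ with a differentiable single step whose *value* is overwritten by the derived gradient, while the numerically-zero term `g - g.detach()` keeps the second-order path through the support gradient alive. This is only a sketch under my own assumptions (names like `derived` and `inner_lr`, a toy `nn.Linear`, a manual functional forward), not the paper's or learn2learn's API.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
inner_lr = 0.1                                  # assumed substituted-step size

theta = nn.Linear(4, 2)
params = list(theta.parameters())
x_s, y_s = torch.randn(16, 4), torch.randint(0, 2, (16,))
x_q, y_q = torch.randn(8, 4), torch.randint(0, 2, (8,))

# pretend φ came out of a k-step inner loop; here θ is just perturbed
phi = copy.deepcopy(theta)
with torch.no_grad():
    for p in phi.parameters():
        p -= 0.05 * torch.randn_like(p)
derived = [(p_phi - p_th).detach()                 # grad = φ − θ, as above
           for p_th, p_phi in zip(params, phi.parameters())]

# differentiable single step at θ, as in second-order MAML
support_loss = F.cross_entropy(theta(x_s), y_s)
g = torch.autograd.grad(support_loss, params, create_graph=True)

# value of ψ is θ + derived == φ; the (g - g.detach()) term is exactly zero
# numerically but re-attaches the second-order path through g
psi = [t + d - inner_lr * (gi - gi.detach())
       for t, d, gi in zip(params, derived, g)]

# recompute the query loss with ψ (functional forward for the toy Linear)
query_loss = F.cross_entropy(F.linear(x_q, psi[0], psi[1]), y_q)
meta_grads = torch.autograd.grad(query_loss, params)
```

Whether this substitution is a faithful stand-in for the hook-based graph surgery is exactly the kind of thing the k = 1 comparison against plain MAML should confirm.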
Although I feel like this should work, I am uncertain about my approach and suspect there should be a more sophisticated way (or a correct way, if I am mistaken in my approach), e.g. constructing the computational graph differently, or using a method such as implicit MAML (without l2 weight regularisation). I am currently writing code to verify my idea against my MAML implementation, to check whether the two agree in the case of a single inner-update step.
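That sanity check can itself be sketched: with a one-step inner loop the derived gradient is exactly −α·g, so (if the graph surgery is right) the substituted update should reproduce second-order MAML's meta-gradient. A toy comparison, assuming the straight-through substitution from my step-4 sketch and made-up names throughout:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
alpha = 0.1                                       # single inner step size
theta = nn.Linear(4, 2)
params = list(theta.parameters())
x_s, y_s = torch.randn(16, 4), torch.randint(0, 2, (16,))
x_q, y_q = torch.randn(8, 4), torch.randint(0, 2, (8,))

support_loss = F.cross_entropy(theta(x_s), y_s)
g = torch.autograd.grad(support_loss, params, create_graph=True)

# (a) plain second-order MAML: keep the graph through the update
psi_maml = [t - alpha * gi for t, gi in zip(params, g)]
loss_maml = F.cross_entropy(F.linear(x_q, psi_maml[0], psi_maml[1]), y_q)
meta_maml = torch.autograd.grad(loss_maml, params, retain_graph=True)

# (b) derived-gradient route: φ from one non-differentiable step, then
# ψ = θ + (φ − θ).detach() − α·(g − g.detach())
with torch.no_grad():
    phi = [t - alpha * gi for t, gi in zip(params, g)]
derived = [(p - t.detach()) for t, p in zip(params, phi)]
psi = [t + d - alpha * (gi - gi.detach()) for t, d, gi in zip(params, derived, g)]
loss_sub = F.cross_entropy(F.linear(x_q, psi[0], psi[1]), y_q)
meta_sub = torch.autograd.grad(loss_sub, params)

match = all(torch.allclose(a, b, atol=1e-5) for a, b in zip(meta_maml, meta_sub))
```

If `match` holds, the substitution reproduces the MAML meta-gradient for k = 1, which is the agreement the verification code is meant to establish.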