Calculating meta-gradients with approximate inner-loop gradient

I am currently working on a re-implementation of a (model-agnostic) meta-learning paper (Lazy MAML, ICCV 2021), which approximates the gradient through the inner-loop process as a single step. This way the inner learner can perform many steps (say k = 10…20), while the meta-gradient is computed cheaply, as though a single inner-loop step (or leap) had been taken.

Converting this logic into actual code gives me some headaches, as I fail to see how it can be properly implemented. My current idea boils down to the following four steps; I run into problems on the last one.

First, let us define that φ denotes the inner-loop model and θ the outer-loop model (the meta-learner).

  1. Do the inner-loop optimization. For this, I plan to use autograd without retain_graph enabled. (For my baseline implementations of MAML/ANIL/BOIL with learn2learn, this needs to be enabled to allow for the outer-loop backprop.) Steps 1–3 are sketched in code after this list.

  2. Calculate the single step from θ to φ, i.e. the update Δ = φ - θ, for which we can loop over their parameters. Viewed as a descent gradient for a single step of size α (the inner-loop learning rate), this is g = (θ - φ) / α, so that θ - α·g recovers φ exactly. I will call this the ‘derived’ gradient that approximates the inner loop.

  3. Using φ, we can calculate the loss on the query set, which we need for the meta-gradient.

  4. Now I need to perform the backpropagation ‘through the inner loop’ using my ‘derived’ gradient update.
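
For concreteness, steps 1–3 might look roughly like this (the model, batch, and hyperparameter names are placeholders, not the paper's code):

```python
import copy
import torch
import torch.nn.functional as F

def inner_loop_and_derived_grad(model, support, query, inner_lr=0.01, inner_steps=10):
    # Step 1: plain inner-loop optimization on a copy of θ -- no retain_graph,
    # no create_graph, so this is cheap and builds no meta-graph.
    phi = copy.deepcopy(model)
    opt = torch.optim.SGD(phi.parameters(), lr=inner_lr)
    x_s, y_s = support
    for _ in range(inner_steps):
        opt.zero_grad()
        F.cross_entropy(phi(x_s), y_s).backward()
        opt.step()

    # Step 2: the 'derived' gradient that would reproduce φ in a single
    # SGD step of size inner_lr starting from θ: g = (θ - φ) / inner_lr.
    derived_grads = [
        (p_theta.detach() - p_phi.detach()) / inner_lr
        for p_theta, p_phi in zip(model.parameters(), phi.parameters())
    ]

    # Step 3: query loss at φ. Only the value is available here; the
    # differentiable version has to be recomputed at ψ in step 4.
    x_q, y_q = query
    with torch.no_grad():
        query_loss = F.cross_entropy(phi(x_q), y_q)
    return phi, derived_grads, query_loss
```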

For step 4, I consider the following, but I feel like I am abusing some of PyTorch’s functionality. First, I could perform a normal MAML update on θ to get ψ, and run a backward pass with autograd.grad with retain_graph enabled (using learn2learn’s update_module functionality to make the update differentiable). Then, using backward hooks, I can set the gradients to my ‘derived’ gradients instead of the modules’ actual gradients, such that ψ == φ, but with the computational graph I need.
This gives me a (fudged) computational graph for autograd, so I can take the second-order derivative with the loss from step 3. However, I would need to re-compute this loss (after first taking a step with my ‘derived’ gradient), as ψ and the loss from step 3 are otherwise unrelated.
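
Here is a minimal sketch of what I have in mind for step 4, reusing the placeholder names from the previous snippet (it requires torch.func.functional_call, PyTorch ≥ 2.0; learn2learn’s update_module could play the same role). Instead of backward hooks it swaps the gradient value with the detach trick, g + (g_derived - g).detach(), which evaluates to the derived gradient but differentiates like the true one. My worry with a plain hook is that returning the constant derived gradient would sever exactly the second-order path I am trying to keep:

```python
import torch
import torch.nn.functional as F

def meta_gradients(model, derived_grads, support, query, inner_lr=0.01):
    x_s, y_s = support
    x_q, y_q = query
    params = dict(model.named_parameters())

    # Dummy differentiable inner step at θ: keep the graph of the gradients.
    support_loss = F.cross_entropy(model(x_s), y_s)
    grads_theta = torch.autograd.grad(
        support_loss, list(params.values()), create_graph=True
    )

    # ψ = θ - α * (g + (g_derived - g).detach()): numerically ψ == φ,
    # but autograd sees the ordinary single-step MAML update θ - α * g.
    psi = {
        name: p - inner_lr * (g_t + (g_d - g_t).detach())
        for (name, p), g_t, g_d in zip(params.items(), grads_theta, derived_grads)
    }

    # Step 3 has to be redone here: the query loss must be evaluated at ψ,
    # otherwise it is unrelated to this graph.
    logits = torch.func.functional_call(model, psi, (x_q,))
    query_loss = F.cross_entropy(logits, y_q)
    return torch.autograd.grad(query_loss, list(params.values()))
```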

Although I feel like this should work, I am uncertain about my approach and suspect there should be a more sophisticated way (or a correct way, if my approach is mistaken), e.g. constructing the computational graph differently, or using a method such as implicit MAML (without l2 weight regularisation). I am currently writing code to verify my idea against my MAML implementation, to check whether the two produce the same meta-gradients in the case of a single inner-update step.
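
That sanity check might look like this, with maml_meta_gradients standing in for my baseline second-order MAML implementation (written out here so the snippet is self-contained, and reusing the hypothetical functions from the sketches above):

```python
import torch
import torch.nn.functional as F

def maml_meta_gradients(model, support, query, inner_lr):
    # Reference: one exact second-order MAML step, no approximation.
    (x_s, y_s), (x_q, y_q) = support, query
    params = dict(model.named_parameters())
    g = torch.autograd.grad(
        F.cross_entropy(model(x_s), y_s), list(params.values()), create_graph=True
    )
    psi = {n: p - inner_lr * g_i for (n, p), g_i in zip(params.items(), g)}
    q_loss = F.cross_entropy(torch.func.functional_call(model, psi, (x_q,)), y_q)
    return torch.autograd.grad(q_loss, list(params.values()))

torch.manual_seed(0)
model = torch.nn.Linear(4, 2)
support = (torch.randn(8, 4), torch.randint(0, 2, (8,)))
query = (torch.randn(8, 4), torch.randint(0, 2, (8,)))

# With k=1 and the same support batch, the derived gradient equals the true
# gradient, so the approximate and exact meta-gradients should coincide.
_, derived, _ = inner_loop_and_derived_grad(model, support, query,
                                            inner_lr=0.1, inner_steps=1)
approx = meta_gradients(model, derived, support, query, inner_lr=0.1)
exact = maml_meta_gradients(model, support, query, inner_lr=0.1)
for g_a, g_e in zip(approx, exact):
    assert torch.allclose(g_a, g_e, atol=1e-6)
```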