Differentiating through a gradient

I have a model f parametrized by w, input data x, and labelled output y. I’m searching for a synthetic input/output pair (x0, y0) and a step size e such that, letting w' = w - e * d(loss(f(x0, w), y0))/dw, the quantity loss(f(w', x), y) is minimized. To find them through gradient descent, I need to compute the gradient of loss(f(w', x), y) with respect to x0, y0, and e.

When f is a relatively simple neural network I can derive the update by hand fairly easily, but I would like to have torch compute those gradients for me. Is it possible to differentiate through the taking of a gradient like this? I’m new to Torch, and I understand this might be possible by passing create_graph=True. However, I’ve struggled a bit to make it work.
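Here is a minimal version of what I believe should work, on a toy linear model (the shapes, the dot-product "model", and the mse_loss choice are placeholders of my own, not anything prescribed):

```python
import torch
import torch.nn.functional as F

# Toy setup: a linear "model" with a single weight vector.
w = torch.randn(3, requires_grad=True)        # model parameters
x0 = torch.randn(3, requires_grad=True)       # synthetic input (to be learned)
y0 = torch.randn(1, requires_grad=True)       # synthetic label (to be learned)
e = torch.tensor(0.1, requires_grad=True)     # step size (to be learned)
x, y = torch.randn(3), torch.randn(1)         # real data

def f(x, w):
    return (w * x).sum().reshape(1)

inner_loss = F.mse_loss(f(x0, w), y0)
# create_graph=True keeps grad_w in the graph, so grad_w is itself
# differentiable with respect to x0, y0, and e.
(grad_w,) = torch.autograd.grad(inner_loss, w, create_graph=True)
w_prime = w - e * grad_w                      # the updated weights w'

outer_loss = F.mse_loss(f(x, w_prime), y)
g_x0, g_y0, g_e = torch.autograd.grad(outer_loss, (x0, y0, e))
```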

Concretely, if f is an nn.Sequential model stacking various layers, how would one write the update to the parameters so that I can then take a gradient with respect to x0, y0, and e? It seems straightforward to update the parameters of the various torch.nn layers with an off-the-shelf optimizer, but this requires something a bit less standard. Is it still possible?
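From reading the docs, my best guess is to treat the parameters as plain tensors via torch.func.functional_call (available in recent PyTorch) and build the updated parameters out-of-place, so that w' stays in the autograd graph. A sketch with made-up layer sizes:

```python
import torch
import torch.nn as nn
from torch.func import functional_call

net = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
params = dict(net.named_parameters())           # name -> tensor dict

x0 = torch.randn(1, 10, requires_grad=True)     # synthetic sample
y0 = torch.randn(1, 1, requires_grad=True)      # synthetic label
e = torch.tensor(0.01, requires_grad=True)      # step size
x, y = torch.randn(64, 10), torch.randn(64, 1)  # real batch
loss_fn = nn.MSELoss()

inner_loss = loss_fn(functional_call(net, params, (x0,)), y0)
grads = torch.autograd.grad(inner_loss, tuple(params.values()),
                            create_graph=True)

# w' = w - e * grad, built out-of-place so it stays differentiable;
# an optimizer's in-place .step() would cut the graph here.
new_params = {name: p - e * g
              for (name, p), g in zip(params.items(), grads)}

outer_loss = loss_fn(functional_call(net, new_params, (x,)), y)
g_x0, g_y0, g_e = torch.autograd.grad(outer_loss, (x0, y0, e))
```

The important point, as far as I can tell, is to avoid optimizer.step() for the inner update, since its in-place changes would detach w' from x0, y0, and e.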

The general idea is to greedily construct a synthetic “curriculum” that can be used to train a network from scratch quickly. It might be possible, for instance, to devise a curriculum on a small network and use it to quickly initialize a much larger one. It helps to think of x and y as the whole dataset, and of x0 and y0 as a synthetic mini-batch, typically with only one sample.
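For concreteness, this is roughly the greedy loop I have in mind, reusing the pattern above (all sizes and hyperparameters are arbitrary placeholders, and the commit step is just one plain SGD update with the found (x0, y0, e)):

```python
import torch
import torch.nn as nn
from torch.func import functional_call

net = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
loss_fn = nn.MSELoss()
x, y = torch.randn(64, 10), torch.randn(64, 1)  # stands in for the whole dataset

curriculum = []
for step in range(20):                           # number of greedy steps, arbitrary
    params = dict(net.named_parameters())
    x0 = torch.randn(1, 10, requires_grad=True)
    y0 = torch.zeros(1, 1, requires_grad=True)
    e = torch.tensor(0.01, requires_grad=True)
    opt = torch.optim.Adam([x0, y0, e], lr=1e-2)

    for _ in range(100):                         # inner optimization, arbitrary
        opt.zero_grad()
        inner = loss_fn(functional_call(net, params, (x0,)), y0)
        grads = torch.autograd.grad(inner, tuple(params.values()),
                                    create_graph=True)
        new_params = {n: p - e * g
                      for (n, p), g in zip(params.items(), grads)}
        outer = loss_fn(functional_call(net, new_params, (x,)), y)
        # Only x0, y0, e are being optimized, so take gradients
        # with respect to them alone.
        x0.grad, y0.grad, e.grad = torch.autograd.grad(outer, (x0, y0, e))
        opt.step()

    curriculum.append((x0.detach(), y0.detach(), e.detach()))

    # Commit the found step to the real weights before the next greedy round.
    inner = loss_fn(net(x0), y0)
    step_grads = torch.autograd.grad(inner, tuple(net.parameters()))
    with torch.no_grad():
        for p, g in zip(net.parameters(), step_grads):
            p -= e * g
```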