I am re-implementing the supervised learning experiments from Model-Agnostic Meta-Learning (MAML) in PyTorch.
The goal is to learn features that are “most fine-tune-able.” This is achieved by taking gradient steps that maximize performance on the validation set after one or more gradient steps on the training set. Because the validation loss depends on the parameters through the training-set gradient, this requires second derivatives with respect to the parameters. See Algorithm 1 in the paper.
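For reference, here is how I read the two updates in Algorithm 1. The train/val superscripts are my own notation for the two sample sets drawn from task $\mathcal{T}_i$; $\alpha$ and $\beta$ are the inner and meta step sizes, and the single inner step is a simplification.

$$
\theta_i' = \theta - \alpha \nabla_\theta \mathcal{L}^{\text{train}}_{\mathcal{T}_i}(f_\theta),
\qquad
\theta \leftarrow \theta - \beta \nabla_\theta \sum_{\mathcal{T}_i} \mathcal{L}^{\text{val}}_{\mathcal{T}_i}(f_{\theta_i'})
$$

Applying the chain rule to one task shows where the second derivative comes from:

$$
\nabla_\theta \mathcal{L}^{\text{val}}_{\mathcal{T}_i}(f_{\theta_i'})
= \left(I - \alpha \nabla_\theta^2 \mathcal{L}^{\text{train}}_{\mathcal{T}_i}(f_\theta)\right)
\nabla_{\theta_i'} \mathcal{L}^{\text{val}}_{\mathcal{T}_i}(f_{\theta_i'})
$$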
Where I am stuck is that I need to do the following:
(1) inner loop: do a forward pass on the training example and take gradients with respect to the parameters;
(2) meta loop: do a forward pass with the updated parameters on a validation example, then take another gradient with respect to the original parameters, backpropagating through the first gradient (hence the second derivative).
From the Improved WGAN implementation, I see that I can take the gradient and retain the graph, allowing me to then take another gradient. But I don’t see how I can do the second forward pass without updating the parameters via opt.step(). Do I need to have two graphs, one where I cache the old parameters for the meta-update, and one where I allow the parameters to update in the inner loop?
Thanks for your help!
EDIT: I understand now that the gradient of each parameter must itself be added to the graph as a variable. Then, in the meta-update, the gradient of these gradients can be taken with a backward pass over the augmented graph.
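A minimal sketch of what I mean, under several assumptions that are mine and not from the paper's code: a toy linear model held as plain tensors, a single task, one inner step, MSE loss, and made-up names (`forward`, `fast_params`, `inner_lr`). The inner update is written as an ordinary tensor expression instead of opt.step(), so the original parameters stay untouched and the adapted ones remain connected to them in the graph.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy model: one linear layer stored as plain tensors, so the
# forward pass can run with either the original or the adapted parameters.
w = torch.randn(1, 3, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
params = [w, b]

def forward(x, params):
    weight, bias = params
    return F.linear(x, weight, bias)

# Toy data (assumed shapes): a training and a validation batch for one task.
x_train, y_train = torch.randn(5, 3), torch.randn(5, 1)
x_val, y_val = torch.randn(5, 3), torch.randn(5, 1)

inner_lr = 0.1
meta_opt = torch.optim.SGD(params, lr=0.01)

# (1) Inner loop: gradient of the training loss w.r.t. the original parameters.
# create_graph=True records the gradient computation itself in the graph,
# which is what makes it differentiable later.
train_loss = F.mse_loss(forward(x_train, params), y_train)
grads = torch.autograd.grad(train_loss, params, create_graph=True)

# Adapted parameters are new graph nodes; w and b are not modified in place.
fast_params = [p - inner_lr * g for p, g in zip(params, grads)]

# (2) Meta loop: forward pass with the adapted parameters on validation data,
# then backprop through the inner step to the original parameters.
val_loss = F.mse_loss(forward(x_val, fast_params), y_val)
meta_opt.zero_grad()
val_loss.backward()   # second-order terms flow through `grads`
meta_opt.step()       # only now are w and b actually updated
```

So there is one graph, not two: the old parameters are its leaves and the updated parameters are intermediate nodes. For a full `nn.Module`, I believe `torch.func.functional_call` (PyTorch 2.0+) allows the same pattern by running the module with a substituted parameter dict, but I have not shown that here.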