[resolved] Implementing MAML in PyTorch


I am re-implementing the supervised learning experiments from Model-Agnostic Meta-Learning (MAML) in PyTorch.

The goal is to learn features that are "most fine-tune-able." This is achieved by taking gradient steps in the direction that maximizes performance on the validation set after one or more gradient steps on the training set. This requires second derivatives with respect to the parameters. See Algorithm 1 in the paper.
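To make the "second derivatives" point concrete, here is a minimal sketch (a toy scalar, not MAML itself) showing how PyTorch can differentiate through a gradient when the first call uses `create_graph=True`:

```python
import torch

# Toy scalar: f(w) = w**3, so df/dw = 3*w**2 and d2f/dw2 = 6*w.
w = torch.tensor(2.0, requires_grad=True)
loss = w ** 3

# create_graph=True keeps the graph of the first gradient, so it can
# itself be differentiated -- this is exactly what MAML's meta-update needs.
(grad,) = torch.autograd.grad(loss, w, create_graph=True)  # 3 * 2**2 = 12
(grad2,) = torch.autograd.grad(grad, w)                    # 6 * 2 = 12
```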

Where I am stuck is that I need to do:

1. **Inner loop:** a forward pass on a training example, then take gradients with respect to the parameters.
2. **Meta step:** a forward pass with the updated parameters on a validation example, then take another gradient with respect to the *original* parameters, backpropagating through the first gradient (hence the second derivative).

From the Improved WGAN implementation, I see that I can take the gradient and retain the graph, allowing me to then take another gradient. But I don’t see how I can do the second forward pass without updating the parameters via opt.step(). Do I need to have two graphs, one where I cache the old parameters for the meta-update, and one where I allow the parameters to update in the inner loop?

Thanks for your help!

EDIT: I understand now that I need to add variables to the graph for each gradient of each variable. Then in the meta-update, the gradient of these gradients can be taken with a backward pass over the augmented graph.
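The approach in the EDIT can be sketched like this: compute the inner-loop gradients with `create_graph=True`, build the adapted parameters as new nodes in the graph instead of calling `opt.step()`, then let the validation loss backpropagate to the original parameters. This is a minimal illustration on a hypothetical linear model (names like `w`, `b`, `inner_lr` are my own), not the full MAML training loop:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
w = torch.randn(1, 3, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
inner_lr, meta_lr = 0.1, 0.01

x_train, y_train = torch.randn(4, 3), torch.randn(4, 1)
x_val, y_val = torch.randn(4, 3), torch.randn(4, 1)

# Inner loop: gradients are taken with create_graph=True, and the adapted
# parameters are built functionally -- w and b themselves are never mutated.
train_loss = F.mse_loss(x_train @ w.t() + b, y_train)
gw, gb = torch.autograd.grad(train_loss, (w, b), create_graph=True)
w_adapted = w - inner_lr * gw
b_adapted = b - inner_lr * gb

# Meta step: forward with the adapted parameters on validation data.
# backward() flows through the inner update to the ORIGINAL w and b,
# which is where the second derivative enters.
val_loss = F.mse_loss(x_val @ w_adapted.t() + b_adapted, y_val)
val_loss.backward()

with torch.no_grad():
    w -= meta_lr * w.grad
    b -= meta_lr * b.grad
```

So only one set of parameters is needed: the "updated" parameters are just intermediate tensors in the augmented graph, which answers the question about avoiding `opt.step()` in the inner loop.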


Were you able to successfully implement it? If yes, can you explain how you did it?

I have a related question here

Yes I have an implementation, but I’ve been unable to replicate the Omniglot experiments from the paper. Let me clean up the code and I will make it public soon.

Here is a link to the MAML code: https://github.com/katerakelly/pytorch-maml
I was able to replicate the Omniglot experiments with it.


Is there an easy way to train networks within the MAML framework without having to code up the functional form? That would make training deeper networks easier.


Can you post your implementation? I am also curious about the "module"-style implementation of MAML.

This is super late, but there are now libraries that nicely implement this. See GitHub - learnables/learn2learn (a PyTorch library for meta-learning research) for a stateful solution, and GitHub - facebookresearch/higher (a PyTorch library for obtaining higher-order gradients over losses spanning whole training loops rather than individual steps) for a stateless solution.