PyTorch implementation of MAML that works with module-style networks?


I see there are several PyTorch implementations of MAML around, e.g.:

… however, they all seem to be written with torch.nn.functional rather than in the nn.Module form, where you would then, e.g., somehow iterate over model.parameters().
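To make the functional pattern concrete, here is a minimal sketch of one MAML inner/outer step as I understand those implementations to do it. The single linear layer, the toy data, and the names `forward`, `fast_params`, and `inner_lr` are my own placeholders, not taken from any particular repo:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Weights live in a plain list of tensors instead of inside an nn.Module.
w = torch.randn(1, 3, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
params = [w, b]

def forward(x, params):
    # Functional forward pass: the weights are passed in explicitly.
    weight, bias = params
    return F.linear(x, weight, bias)

# Toy support/query data (made up for illustration).
x_support, y_support = torch.randn(8, 3), torch.randn(8, 1)
x_query, y_query = torch.randn(8, 3), torch.randn(8, 1)

inner_lr = 0.1  # hypothetical inner-loop step size

# Inner step: create_graph=True keeps the gradients themselves differentiable.
support_loss = F.mse_loss(forward(x_support, params), y_support)
grads = torch.autograd.grad(support_loss, params, create_graph=True)
fast_params = [p - inner_lr * g for p, g in zip(params, grads)]

# Outer step: the query loss uses the adapted ("fast") weights, and
# backward() propagates through the inner update to the original weights.
query_loss = F.mse_loss(forward(x_query, fast_params), y_query)
query_loss.backward()
```

The key point is that `fast_params` are ordinary non-leaf tensors that stay in the autograd graph, which is easy when the forward pass takes its weights as arguments.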

Is this a fundamental limitation of PyTorch? Or is there some way of running MAML without having to rewrite the entire network in functional form?


Wow, I was asking myself the exact same thing. I am working on an adaptation of MAML-based iteration for multi-agent learning, and it worked because I am not using nn.Module.

Now I want to try it on larger networks (a CNN with a recurrent layer), but I have to create everything by hand with nn.functional and implement my own gradient descent.

So far, it seems that PyTorch can’t differentiate through updates to a Module’s parameters.

I have created this feature request:

Let’s see what could be done.
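One possible direction, as a hedged sketch only: PyTorch releases that ship the `torch.func` namespace (2.0 and later, as far as I know) provide `functional_call`, which runs an existing nn.Module with a substitute dictionary of parameters. Assuming such a version is available, the inner update can stay in the graph without rewriting the network. The small `nn.Sequential` model, the toy data, and the names `fast_params` and `inner_lr` below are placeholders for illustration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call  # assumes a PyTorch version with torch.func

torch.manual_seed(0)

# An ordinary module-style network, left unchanged.
model = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 1))
params = dict(model.named_parameters())

# Toy support/query data (made up for illustration).
x_support, y_support = torch.randn(8, 3), torch.randn(8, 1)
x_query, y_query = torch.randn(8, 3), torch.randn(8, 1)

inner_lr = 0.1  # hypothetical inner-loop step size

# Inner step: differentiate the support loss w.r.t. the module's parameters,
# keeping the graph so the update itself can be differentiated later.
support_loss = F.mse_loss(functional_call(model, params, (x_support,)), y_support)
grads = torch.autograd.grad(support_loss, list(params.values()), create_graph=True)
fast_params = {
    name: p - inner_lr * g for (name, p), g in zip(params.items(), grads)
}

# Outer step: run the same module with the adapted parameters swapped in.
query_loss = F.mse_loss(functional_call(model, fast_params, (x_query,)), y_query)
query_loss.backward()  # gradients land on the module's original parameters
```

After `backward()`, `model.parameters()` carry the meta-gradients, so a regular optimizer such as `torch.optim.Adam(model.parameters())` could take the outer step. I have not checked how well this scales to a CNN with a recurrent layer.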