PyTorch implementation of MAML that works with module-style networks?

Hi,

I see there are several PyTorch implementations of MAML around, e.g.:

… however, they all seem to work by using torch.nn.functional rather than the nn.Module form, and then e.g. iterating over model.parameters() by hand.

Is this a fundamental limitation of PyTorch, or is there some way of running MAML without having to rewrite the entire network in functional form?
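
For reference, the functional style I mean looks roughly like the sketch below (the network, names like `forward_functional` and `fast_weights`, and the hyperparameters are just illustrative, not taken from any particular repo):

```python
import torch
import torch.nn.functional as F

# Functional forward pass: weights are passed in explicitly instead of
# living inside an nn.Module, so the updated ("fast") weights can stay
# connected to the autograd graph.
def forward_functional(x, weights):
    x = F.relu(F.linear(x, weights['w1'], weights['b1']))
    return F.linear(x, weights['w2'], weights['b2'])

# Initial (meta) parameters
weights = {
    'w1': torch.randn(32, 10, requires_grad=True),
    'b1': torch.zeros(32, requires_grad=True),
    'w2': torch.randn(1, 32, requires_grad=True),
    'b2': torch.zeros(1, requires_grad=True),
}

x_support, y_support = torch.randn(8, 10), torch.randn(8, 1)
x_query, y_query = torch.randn(8, 10), torch.randn(8, 1)
inner_lr = 0.01

# Inner-loop update: create_graph=True keeps the update differentiable,
# so the outer (meta) loss can backprop through it.
inner_loss = F.mse_loss(forward_functional(x_support, weights), y_support)
grads = torch.autograd.grad(inner_loss, list(weights.values()), create_graph=True)
fast_weights = {name: w - inner_lr * g
                for (name, w), g in zip(weights.items(), grads)}

# Outer loss uses the fast weights; backward() reaches the original
# meta-parameters through the inner update.
outer_loss = F.mse_loss(forward_functional(x_query, fast_weights), y_query)
outer_loss.backward()
print(weights['w1'].grad.shape)
```

That is exactly the part I would like to avoid having to write by hand for an existing nn.Module network.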


Wow, I was asking myself the exact same thing. I am working on an adaptation of MAML-based iteration for multi-agent learning (https://github.com/alexis-jacq/LOLA_DiCE), and it worked because I was not using nn.Module.

Now I want to try it on larger networks (a CNN with a recurrent layer), but I have to build everything by hand with nn.functional and implement my own gradient descent updates.

So far, it seems that PyTorch can’t differentiate through nn.Module parameter updates.
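
Here is a minimal sketch of what I mean, with a toy nn.Linear and an arbitrary learning rate: the module’s parameters are leaf tensors, so writing an updated value back into the module discards the autograd history of the inner-loop update.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

net = nn.Linear(10, 1)
x, y = torch.randn(8, 10), torch.randn(8, 1)

inner_loss = F.mse_loss(net(x), y)
grads = torch.autograd.grad(inner_loss, list(net.parameters()), create_graph=True)

# Differentiable update -> non-leaf tensors that still depend on net's weights
updated = [p - 0.01 * g for p, g in zip(net.parameters(), grads)]
print(updated[0].grad_fn)   # SubBackward0: the graph is there

# But storing the result back into the module detaches it again:
with torch.no_grad():
    for p, u in zip(net.parameters(), updated):
        p.copy_(u)          # the leaf parameter keeps only the values

outer_loss = F.mse_loss(net(x), y)
outer_loss.backward()       # gradients w.r.t. the copied-in values only;
                            # no second-order path through the inner update
```

So the differentiable update is easy to compute, but there is no clean way to keep it attached to the graph while still using the module’s own forward.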

I have created this feature request: https://github.com/pytorch/pytorch/issues/12659

Let’s see what could be done.