I’m implementing MAML, but I’m struggling to understand how to calculate the gradient of the outer loss w.r.t. the model’s parameters as they were before any inner gradient step. So I made a naive example, but I’m not sure if I missed anything?
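For illustration, here is a minimal sketch of what such a toy setup could look like. It is not necessarily the example from the original post: the scalar parameter `w`, the data values, and `inner_lr` are all assumptions chosen to keep the arithmetic easy to follow. The key point it tries to show is that the inner gradient is computed with `create_graph=True`, so the updated weight stays connected to the starting `w`.

```python
import torch

# Hypothetical toy problem: a single scalar weight and a squared-error loss.
w = torch.tensor(2.0, requires_grad=True)  # starting (meta) parameter
inner_lr = 0.1                             # assumed inner-loop step size

x_support, y_support = torch.tensor(1.0), torch.tensor(3.0)
x_query, y_query = torch.tensor(2.0), torch.tensor(5.0)

# Inner step: keep the graph so the update itself is differentiable.
inner_loss = (w * x_support - y_support) ** 2
(grad_w,) = torch.autograd.grad(inner_loss, w, create_graph=True)
w_fast = w - inner_lr * grad_w  # non-leaf tensor that still depends on w

# Outer loss is evaluated with the updated weight, then backpropagated
# all the way to the starting w (including the second-order term).
outer_loss = (w_fast * x_query - y_query) ** 2
outer_loss.backward()

meta_grad = w.grad  # gradient of the outer loss w.r.t. the starting w
print(meta_grad)
```

If `create_graph=True` is left out, `grad_w` is treated as a constant and the result is the first-order approximation instead of the full MAML gradient.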

Thank you for your answer, I knew that I missed something!
This is my toy example for trying to understand MAML, and I think I now more or less understand it and am familiar with PyTorch.
When implementing MAML, I have a model (nn.Module) with parameters W (a stack of CNN and Linear layers). How can I keep my W but also run the forward pass with W_after_update, so that I can still backpropagate the outer_loss to my starting W?

I’m afraid it is a tricky thing to do.
If you follow MAML quite closely, I would recommend using https://github.com/facebookresearch/higher/, which takes care of all this state management for you.
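If an extra dependency is not an option, one way to get the same effect in plain PyTorch is to run the module functionally with the updated weights. The sketch below assumes PyTorch 2.0 or newer (for `torch.func.functional_call`); the small `nn.Sequential` model, the random data, and `inner_lr` are placeholders, not the model from the question. It performs a single inner step only.

```python
import torch
import torch.nn as nn
from torch.func import functional_call  # assumes PyTorch >= 2.0

torch.manual_seed(0)

# Placeholder model and data; substitute your own nn.Module and task batches.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
loss_fn = nn.MSELoss()
inner_lr = 0.01  # assumed inner-loop step size

x_support, y_support = torch.randn(5, 4), torch.randn(5, 1)
x_query, y_query = torch.randn(5, 4), torch.randn(5, 1)

# W: the module's own parameters, which are never modified in place.
params = dict(model.named_parameters())

# Inner step: differentiate the support loss and keep the graph.
inner_loss = loss_fn(functional_call(model, params, (x_support,)), y_support)
grads = torch.autograd.grad(inner_loss, tuple(params.values()), create_graph=True)

# W_after_update: new non-leaf tensors that remain functions of W.
fast_params = {
    name: p - inner_lr * g for (name, p), g in zip(params.items(), grads)
}

# Forward with W_after_update, then backpropagate the outer loss to W.
outer_loss = loss_fn(functional_call(model, fast_params, (x_query,)), y_query)
outer_loss.backward()

# model.parameters() now hold the meta-gradients in .grad.
```

After this, an ordinary optimizer built over `model.parameters()` can take the outer step. Buffers (for example BatchNorm running statistics) are not handled here and would need extra care.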

Thank you for your reply! Higher looks really promising.
Since I now understand the concept of PyTorch gradients, I will try to implement it with plain PyTorch (I am thinking about a deep copy and passing grad_output back).
Once again, thank you so much for your help.