Example of second order derivative (single variable)

I’m implementing MAML, but I’m currently struggling to understand how to calculate the gradient of the outer loss w.r.t. the model’s parameters as they were before any inner gradient step. So I made a naive example, but I’m not sure if I missed anything?

import torch
from torch.autograd import grad

# f(x) = 2x + 3

x_train = torch.tensor(2., requires_grad=True)
y_train = torch.tensor(7., requires_grad=True)  # 2 * 2 + 3 = 7

x_test = torch.tensor(0.2, requires_grad=True)
y_test = torch.tensor(3.4, requires_grad=True)  # 0.2 * 2 + 3 = 3.4

W = torch.tensor(1., requires_grad=True)

n_inner_update = 3

for step in range(2000):
    y_pred = x_train * W + 3.
    loss = torch.abs(y_pred - y_train)
    inner_grad = grad(loss, W, create_graph=True)[0]

    W_after_update = W - inner_grad * 0.002  # an inner update

    for inner in range(n_inner_update - 1):
        y_pred = x_train * W_after_update + 3.
        loss = torch.abs(y_pred - y_train)
        inner_grad = grad(loss, W_after_update, create_graph=True)[0]
        W_after_update = W_after_update - inner_grad * 0.002

    y_pred_test = x_test * W_after_update + 3.
    outer_loss = torch.abs(y_pred_test - y_test)
    outer_loss_wrt_W = grad(outer_loss, W)[0]
    W.data -= outer_loss_wrt_W * 0.01
    print(loss.item())

Note that the right way to update the weight is with:

with torch.no_grad():
    W -= outer_loss_wrt_W * 0.01

.data should not be used anymore!

Your code looks good at first glance. What is the issue that you have with it?

Thank you for your answer, I knew that I had missed something!
This is my toy example for trying to understand MAML; I think I’m reasonably familiar with PyTorch.
When implementing MAML, I have a model (nn.Module) with parameters W (a stack of CNN and Linear layers). How can I keep my original W but also forward with W_after_update, so that I can still backpropagate the outer_loss to my starting W?

I’m afraid it is a tricky thing to do.
If you follow MAML quite closely, I would recommend using https://github.com/facebookresearch/higher/, which takes care of all this state management for you.
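
For reference, the usual pattern with higher looks roughly like this (a sketch adapted from its README; the nn.Linear model, the L1 loss, and the learning rates are placeholders, not something from your code):

import torch
import torch.nn as nn
import higher

model = nn.Linear(1, 1)  # stand-in for the CNN + Linear stack
inner_opt = torch.optim.SGD(model.parameters(), lr=0.002)
meta_opt = torch.optim.Adam(model.parameters(), lr=0.01)

x_train = torch.tensor([[2.0]])
y_train = torch.tensor([[7.0]])
x_test = torch.tensor([[0.2]])
y_test = torch.tensor([[3.4]])

for step in range(2000):
    meta_opt.zero_grad()
    # copy_initial_weights=False so the outer gradient reaches the original parameters
    with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
        for _ in range(3):  # inner updates on the support set
            inner_loss = torch.abs(fmodel(x_train) - y_train).mean()
            diffopt.step(inner_loss)  # differentiable inner step
        outer_loss = torch.abs(fmodel(x_test) - y_test).mean()
        outer_loss.backward()  # gradients flow back to model.parameters()
    meta_opt.step()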

Thank you for your reply! Higher looks really promising.
Since I understand the concept of PyTorch’s grad now, I will try to implement it with plain PyTorch (I’m thinking about deep copies and passing back grad_output).
Once again, thank you so much for your help.

You can do it with pytorch (that’s what higher does :stuck_out_tongue: ).
Be careful that the nn module is unfortunately not built for that.
Keep in mind in particular:

  • nn.Parameter can only contain leaf tensors (tensors with no gradient history).
  • optim.step() is NOT differentiable and won’t propagate gradients, so the inner updates have to be written by hand, as in the sketch below.
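
A minimal sketch of that manual approach (my own illustration, not code from this thread; the TinyNet module, the L1 loss, and the learning rates are just placeholders): keep the module’s own parameters as leaves, and run the inner-loop forward with explicitly passed “fast weights” so the outer loss can still backpropagate to the original parameters.

import torch
import torch.nn.functional as F
from torch.autograd import grad

class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(1, 1)

    def forward(self, x, params=None):
        # params is an optional list of "fast weights"; default to the module's own leaves
        if params is None:
            params = [self.fc.weight, self.fc.bias]
        return F.linear(x, params[0], params[1])

net = TinyNet()
x_train = torch.tensor([[2.0]])
y_train = torch.tensor([[7.0]])
x_test = torch.tensor([[0.2]])
y_test = torch.tensor([[3.4]])

# Inner loop: build fast weights without touching the module's own parameters.
fast = list(net.parameters())
for _ in range(3):
    inner_loss = torch.abs(net(x_train, fast) - y_train).mean()
    grads = grad(inner_loss, fast, create_graph=True)
    fast = [p - 0.002 * g for p, g in zip(fast, grads)]  # non-leaf tensors

# Outer loss: forward with the fast weights, differentiate w.r.t. the original leaves.
outer_loss = torch.abs(net(x_test, fast) - y_test).mean()
outer_grads = grad(outer_loss, list(net.parameters()))
with torch.no_grad():
    for p, g in zip(net.parameters(), outer_grads):
        p -= 0.01 * g

The key point is that the fast weights never replace the nn.Parameters; they only live inside the forward pass, so grad(outer_loss, net.parameters()) still has a path back to the starting weights.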

Good luck!
Don’t hesitate to post here if you need inputs.
