Parameter updates with second-order gradients and an in-place copy operation

Hi all, my question is related to the following post:

Example of second order derivative (single variable)

In the above post, the questioner proposed a MAML example.
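
For context, a single-variable second derivative like the one in that post can be computed with torch.autograd.grad by passing create_graph=True to the first call. That flag makes the returned gradient differentiable, so it can be passed to grad again. This is a minimal sketch that assumes the toy function f(x) = x**3, which is chosen here for illustration and is not taken from the linked post:

```python
import torch
from torch.autograd import grad

# Illustrative function: f(x) = x**3, so f'(x) = 3x**2 and f''(x) = 6x
x = torch.tensor(3., requires_grad=True)
y = x ** 3

# create_graph=True records the backward pass itself, so dy_dx can be differentiated again
dy_dx = grad(y, x, create_graph=True)[0]  # 3 * 3**2 = 27
d2y_dx2 = grad(dy_dx, x)[0]               # 6 * 3 = 18

print(dy_dx.item(), d2y_dx2.item())
```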

My first question: can torch.autograd handle an in-place copy operation that updates parameters with second-order gradients? For example, does the following toy MAML example work?


import torch
from torch.autograd import grad

# f(x) = 2x + 3

x_train = torch.tensor(2., requires_grad=True)
y_train = torch.tensor(7., requires_grad=True)  # 2 * 2 + 3 = 7

x_test = torch.tensor(0.2, requires_grad=True)
y_test = torch.tensor(3.4, requires_grad=True)  # 0.2 * 2 + 3 = 3.4

W = torch.nn.Parameter(torch.tensor(1., requires_grad=True))

n_inner_update = 3

for step in range(20):
    # Initial inner loss and gradient at the current W
    y_pred = x_train * W + 3.
    loss = torch.abs(y_pred - y_train)
    inner_grad = grad(loss, W, create_graph=True)[0]

    # Inner loop: adapt W on the training point, overwriting W in place
    for inner in range(n_inner_update - 1):
        y_pred = x_train * W + 3.
        loss = torch.abs(y_pred - y_train)
        inner_grad = grad(loss, W, create_graph=True)[0]
        with torch.no_grad():
            W.copy_(W - inner_grad * 0.002)

    # Outer (meta) step: gradient of the test loss w.r.t. W
    y_pred_test = x_test * W + 3.
    outer_loss = torch.abs(y_pred_test - y_test)
    outer_loss_wrt_W = grad(outer_loss, W)[0]
    with torch.no_grad():
        W.copy_(W - outer_loss_wrt_W * 0.01)
    print(loss.item())
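
This is a minimal sketch of one common alternative. It assumes the goal is full second-order MAML. W stays fixed as the meta-parameter during the inner loop. The adapted weights are built out of place in W_fast (a hypothetical name) using create_graph=True. The outer gradient is then taken with respect to the original W, so autograd differentiates through every inner step. The sketch makes two deliberate changes from the code above. First, squared error replaces torch.abs. The second derivative of |·| is zero almost everywhere, so with abs the second-order term would vanish and the result would equal first-order MAML. Second, the inner loop runs n_inner_update times without a separate initial gradient call.

```python
import torch
from torch.autograd import grad

# Toy task from the post: f(x) = 2x + 3 (inputs do not need requires_grad here)
x_train, y_train = torch.tensor(2.), torch.tensor(7.)
x_test, y_test = torch.tensor(0.2), torch.tensor(3.4)

W = torch.nn.Parameter(torch.tensor(1.))  # meta-parameter, changed only in the outer step
inner_lr, outer_lr = 0.002, 0.01
n_inner_update = 3

for step in range(20):
    W_fast = W  # hypothetical "fast weights"; starts as W itself
    for _ in range(n_inner_update):
        # Squared error (assumption) so that the second derivative is non-zero
        inner_loss = (x_train * W_fast + 3. - y_train) ** 2
        # create_graph=True keeps inner_grad differentiable with respect to W
        inner_grad = grad(inner_loss, W_fast, create_graph=True)[0]
        # Out-of-place update: creates a new graph node and leaves W untouched
        W_fast = W_fast - inner_lr * inner_grad

    outer_loss = (x_test * W_fast + 3. - y_test) ** 2
    # Backpropagates through all inner steps, including the second-order terms
    outer_grad = grad(outer_loss, W)[0]
    with torch.no_grad():
        W -= outer_lr * outer_grad  # in-place update of the leaf is allowed under no_grad

print(W.item())
```

Because `W_fast = W_fast - inner_lr * inner_grad` creates a new tensor instead of overwriting W, no in-place operation touches the leaf during the inner loop. The only in-place update (`W -= ...`) runs under torch.no_grad() after the outer gradient has been computed.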

My second question: why does the following error occur if torch.no_grad() is removed from the inner-loop update (W.copy_(W - inner_grad * 0.002))?
Can an in-place copy operation be a node in a computational graph whose nodes carry gradients?


RuntimeError                              Traceback (most recent call last)
in
     24 inner_grad = grad(loss, W, create_graph=True)[0]
     25 # with torch.no_grad():
---> 26 W.copy_(W - inner_grad* 0.002)
     27
     28 y_pred_test = x_test * W + 3.

RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
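
The error can be reproduced in isolation. This is a minimal sketch that assumes a scalar nn.Parameter like W in the post. It shows three cases:

1. Outside torch.no_grad(), an in-place copy_ into a leaf that requires grad raises this RuntimeError.
2. Inside torch.no_grad(), the copy is allowed but not recorded. W stays a leaf with no history, so no gradient flows through that update.
3. A copy_ into a tensor that is not a grad-requiring leaf is recorded. In that case an in-place copy does become a node in the graph.

The names src and dst are illustrative.

```python
import torch

W = torch.nn.Parameter(torch.tensor(1.))

# 1) Outside no_grad: in-place op on a leaf that requires grad is rejected
try:
    W.copy_(W - 0.1)
    raised = False
except RuntimeError as e:
    raised = True
    print("RuntimeError:", e)

# 2) Inside no_grad: allowed, but not recorded; W remains a leaf without grad_fn
with torch.no_grad():
    W.copy_(W - 0.1)
print(W.is_leaf, W.grad_fn)  # True None

# 3) copy_ into a tensor that does not require grad is recorded in the graph
src = torch.tensor(3., requires_grad=True)
dst = torch.zeros(())  # leaf that does not require grad
dst.copy_(src * 2)     # dst now requires grad and has a grad_fn
dst.backward()
print(dst.grad_fn is not None, src.grad)
```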