Hi all, my question is related to the following question.
In that post, the questioner proposed a MAML example.
My first question is: can torch.autograd compute second-order gradients when an in-place copy operation is used to update the parameters? For example, can the following toy MAML example work?
```python
import torch
from torch.autograd import grad

# Ground truth: f(x) = 2x + 3
x_train = torch.tensor(2., requires_grad=True)
y_train = torch.tensor(7., requires_grad=True)    # 2 * 2 + 3 = 7
x_test = torch.tensor(0.2, requires_grad=True)
y_test = torch.tensor(3.4, requires_grad=True)    # 0.2 * 2 + 3 = 3.4

W = torch.nn.Parameter(torch.tensor(1., requires_grad=True))
n_inner_update = 3

for step in range(20):
    y_pred = x_train * W + 3.
    loss = torch.abs(y_pred - y_train)
    # grad() returns a tuple, so unpack the single element
    inner_grad, = grad(loss, W, create_graph=True)

    # inner-loop (task-specific) updates
    for inner in range(n_inner_update - 1):
        y_pred = x_train * W + 3.
        loss = torch.abs(y_pred - y_train)
        inner_grad, = grad(loss, W, create_graph=True)
        with torch.no_grad():
            W.copy_(W - inner_grad * 0.002)

    # outer-loop (meta) update on the test point
    y_pred_test = x_test * W + 3.
    outer_loss = torch.abs(y_pred_test - y_test)
    outer_loss_wrt_W, = grad(outer_loss, W)
    with torch.no_grad():
        W.copy_(W - outer_loss_wrt_W * 0.01)
    print(loss.item())
```
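For comparison, my understanding is that MAML implementations usually avoid the in-place copy by carrying the adapted weights as ordinary (non-leaf) tensors, so the inner updates themselves stay in the autograd graph. Here is a rough sketch of what I mean (`fast_W` is my own naming, not from the linked post):

```python
import torch
from torch.autograd import grad

x_train, y_train = torch.tensor(2.), torch.tensor(7.)
x_test, y_test = torch.tensor(0.2), torch.tensor(3.4)

W = torch.nn.Parameter(torch.tensor(1.))
n_inner_update = 3

for step in range(20):
    fast_W = W  # adapted weight; starts at the meta-parameter
    for inner in range(n_inner_update):
        y_pred = x_train * fast_W + 3.
        loss = torch.abs(y_pred - y_train)
        g, = grad(loss, fast_W, create_graph=True)
        # out-of-place update: fast_W becomes a non-leaf tensor,
        # so the update itself is a node in the graph
        fast_W = fast_W - 0.002 * g

    y_pred_test = x_test * fast_W + 3.
    outer_loss = torch.abs(y_pred_test - y_test)
    # differentiates through the inner updates, including second-order terms
    outer_g, = grad(outer_loss, W)
    with torch.no_grad():
        W.copy_(W - 0.01 * outer_g)
    print(loss.item())
```

If this out-of-place version is the one that actually propagates second-order gradients, does the in-place `copy_` under `torch.no_grad()` in my code above silently drop them?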
My second question is: why does the following error happen if torch.no_grad() is removed from the update line?
Can an in-place copy operation be a node of the computational graph, so that gradients flow through it?
```
RuntimeError                              Traceback (most recent call last)
     24         inner_grad, = grad(loss, W, create_graph=True)
     25         # with torch.no_grad():
---> 26         W.copy_(W - inner_grad * 0.002)
     28         y_pred_test = x_test * W + 3.

RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
```
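Here is what I believe is a minimal reproduction of the same error, stripped of the MAML context:

```python
import torch

W = torch.nn.Parameter(torch.tensor(1.))  # leaf tensor with requires_grad=True

try:
    W.copy_(W - 0.1)  # in-place update of a leaf outside torch.no_grad()
except RuntimeError as e:
    print(e)  # a leaf Variable that requires grad is being used in an in-place operation.

with torch.no_grad():
    W.copy_(W - 0.1)  # allowed, but not recorded in the autograd graph
print(W)  # tensor(0.9000, requires_grad=True)
```

So my guess is that autograd simply refuses to record in-place operations on leaf tensors, and torch.no_grad() avoids the error by keeping the copy out of the graph entirely, which would also explain why no second-order gradients can flow through it. Is that correct?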