The code is indeed correct; I checked it against the following analytic computation.
import torch
import torch.nn as nn
# gradient
def closed_form_gradient(x):
    return torch.tensor([[18.*x[0][0] + 16.*x[0][1], 16.*x[0][0] + 18.*x[0][1]]])
# (constant) hessian
hessian = torch.tensor([[18., 16.], [16., 18.]])
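# (gradient and hessian above are consistent with a quadratic loss of the form
#  f(w) = 0.5 * w @ hessian @ w.T = 9*w0**2 + 16*w0*w1 + 9*w1**2; this is only
#  assumed here to show where the constants 18 and 16 come from)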
# set up weight manually
net = nn.Sequential(nn.Linear(2, 1, bias=False))
with torch.no_grad():
    net[0].weight = nn.Parameter(torch.tensor([[.1, .2]]))
lr_ = 1e-3
## GRADIENT COMPUTATION USING CLOSED FORM SOLUTION
grad_in_1_analytic = closed_form_gradient(net[0].weight)
fast_weight_in_1_analytic = net[0].weight - lr_ * grad_in_1_analytic
print("Grad:", grad_in_1_analytic)
print("Parameter:", fast_weight_in_1_analytic)
print()
grad_in_2_analytic = closed_form_gradient(fast_weight_in_1_analytic)
fast_weight_in_2_analytic = fast_weight_in_1_analytic - lr_ * grad_in_2_analytic
print("Grad:", grad_in_2_analytic)
print("Parameter:", fast_weight_in_2_analytic)
print()
grad_in_3_analytic = closed_form_gradient(fast_weight_in_2_analytic)
grad_out_analytic = grad_in_3_analytic @ (torch.eye(2) - lr_ * hessian) @ (torch.eye(2) - lr_ * hessian)
print("Grad:", grad_out_analytic)
To answer your main question: this solution does not store the computation graphs of all tasks at once. Instead, it builds the computation graph for a single task, computes that task's meta-gradient, stores the meta-gradient in a list that can later be used (e.g. averaged) to meta-update the main network, and then repeats for the next task. The loop structure is sketched below.
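In pseudo-PyTorch, it looks roughly like the following. This is only a sketch of that description, not your actual code: task_batch, task_loss, num_inner_steps, inner_lr, and meta_lr are hypothetical placeholders, and task_loss is assumed to run net functionally with the given fast weights.
import torch

meta_grads = []                                   # per-task meta-gradients
for task in task_batch:                           # hypothetical iterable of tasks
    fast_weights = list(net.parameters())
    for _ in range(num_inner_steps):              # inner-loop adaptation on the support set
        loss = task_loss(net, fast_weights, task.support)
        grads = torch.autograd.grad(loss, fast_weights, create_graph=True)
        fast_weights = [w - inner_lr * g for w, g in zip(fast_weights, grads)]
    meta_loss = task_loss(net, fast_weights, task.query)
    # meta-gradient w.r.t. the original parameters; the graph built for this
    # task is freed as soon as this call returns, so graphs never pile up
    meta_grads.append(torch.autograd.grad(meta_loss, list(net.parameters())))

# average the stored per-task meta-gradients and update the main network
with torch.no_grad():
    for p, *gs in zip(net.parameters(), *meta_grads):
        p -= meta_lr * torch.stack(gs).mean(dim=0)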