How can you do memory-efficient MAML?

The code is indeed correct. I compared it against the following analytic computation.

import torch
import torch.nn as nn

# closed-form gradient of the loss, evaluated at weight x
def closed_form_gradient(x):
  return torch.tensor([[18.*x[0][0] + 16.*x[0][1], 16.*x[0][0] + 18.*x[0][1]]])

# (constant) Hessian of the loss
hessian = torch.tensor([[18., 16.], [16., 18.]])

# set up the weight manually (single linear layer, no bias)
net = nn.Sequential(nn.Linear(2, 1, False))
with torch.no_grad():
  net[0].weight = nn.Parameter(torch.tensor([[.1, .2]]))

lr_ = 1e-3

## GRADIENT COMPUTATION USING CLOSED FORM SOLUTION

# first inner-loop update: w1 = w0 - lr * g(w0)
grad_in_1_analytic        = closed_form_gradient(net[0].weight)
fast_weight_in_1_analytic = net[0].weight - lr_ * grad_in_1_analytic

print("Grad:", grad_in_1_analytic)
print("Parameter:", fast_weight_in_1_analytic)
print()

# second inner-loop update: w2 = w1 - lr * g(w1)
grad_in_2_analytic        = closed_form_gradient(fast_weight_in_1_analytic)
fast_weight_in_2_analytic = fast_weight_in_1_analytic - lr_ * grad_in_2_analytic

print("Grad:", grad_in_2_analytic)
print("Parameter:", fast_weight_in_2_analytic)
print()

# meta-gradient by the chain rule: since w2 = w1 - lr*g(w1) and w1 = w0 - lr*g(w0),
# dL(w2)/dw0 = g(w2) @ (I - lr*H) @ (I - lr*H) when the Hessian H is constant
grad_in_3_analytic  = closed_form_gradient(fast_weight_in_2_analytic)
grad_out_analytic   = grad_in_3_analytic @ (torch.eye(2) - lr_ * hessian) @ (torch.eye(2) - lr_ * hessian)

print("Grad:", grad_out_analytic)

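For completeness, the same number can be reproduced with autograd. The loss itself is not shown in the snippet, but any quadratic loss whose gradient matches closed_form_gradient gives the same result; quadratic_loss below is an assumed stand-in (it agrees with the closed-form gradient and Hessian above, up to an additive constant), not code from this thread.

# autograd counterpart of the analytic computation above (sketch)
def quadratic_loss(w):
  return 9.*w[0][0]**2 + 16.*w[0][0]*w[0][1] + 9.*w[0][1]**2

w0 = net[0].weight
fast_weight = w0
for _ in range(2):
  # inner-loop step with a differentiable update (create_graph keeps second-order terms)
  g, = torch.autograd.grad(quadratic_loss(fast_weight), fast_weight, create_graph=True)
  fast_weight = fast_weight - lr_ * g

# gradient of the post-update loss w.r.t. the original weight; should match grad_out_analytic
grad_out_autograd, = torch.autograd.grad(quadratic_loss(fast_weight), w0)
print("Grad (autograd):", grad_out_autograd)
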
To answer your main question: this solution does not store the computation graphs for all tasks at once. Instead, it constructs the computation graph for one task, computes that task's meta-gradient, stores the meta-gradient in a list that can later be used (e.g. averaged) to meta-update the main network, and then repeats for the next task.
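
For illustration, here is a minimal sketch of that per-task loop. It assumes a bias-free linear model, two inner SGD steps, and randomly generated placeholder data; task_loss, inner_lr, num_tasks and the support/query tensors are made-up names for the sketch, not code from this thread.

import torch
import torch.nn as nn

torch.manual_seed(0)

meta_net  = nn.Linear(2, 1, bias=False)
meta_opt  = torch.optim.SGD(meta_net.parameters(), lr=1e-2)
inner_lr  = 1e-3
num_tasks = 4

def task_loss(weight, x, y):
  # placeholder per-task MSE loss for a bias-free linear model
  return ((x @ weight.t() - y) ** 2).mean()

meta_grads = []
for _ in range(num_tasks):
  # placeholder support/query data for one task
  x_s, y_s = torch.randn(10, 2), torch.randn(10, 1)
  x_q, y_q = torch.randn(10, 2), torch.randn(10, 1)

  # inner loop: the computation graph is built for this task only
  fast_weight = meta_net.weight
  for _ in range(2):
    g, = torch.autograd.grad(task_loss(fast_weight, x_s, y_s),
                             fast_weight, create_graph=True)
    fast_weight = fast_weight - inner_lr * g

  # meta-gradient of the query loss w.r.t. the meta-parameters;
  # the task's graph is freed after this call
  meta_g, = torch.autograd.grad(task_loss(fast_weight, x_q, y_q),
                                meta_net.weight)
  meta_grads.append(meta_g.detach())

# average the stored meta-gradients and take one meta-update step
meta_net.weight.grad = torch.stack(meta_grads).mean(dim=0)
meta_opt.step()

Because each task's meta-gradient is extracted and detached before the next task starts, only one task's graph is alive at any time, which is where the memory saving comes from.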
