I think we stumbled upon a very similar issue regarding the MAML implementation. I haven't tested it yet, but I think yes, we can do the updates after each task without storing the computational graphs for all tasks.

```
import torch
import torch.nn as nn
from learn2learn.utils import clone_module, update_module

batch_size = 3
task_size = 2
lr = 1e-3

# set up the weight manually so the result can be compared to a hand computation
net = nn.Sequential(nn.Linear(2, 1, bias=False))
with torch.no_grad():
    net[0].weight = nn.Parameter(torch.tensor([[.1, .2]]))

# dataset = [torch.randn(batch_size, 2) for _ in range(task_size)]
dataset = [torch.tensor([[1., 2.], [2., 1.], [2., 2.]])]

meta_grads = []
for data in dataset:
    temp_net = clone_module(net)

    # first inner-loop step
    loss_val = nn.functional.mse_loss(net(data), torch.zeros((batch_size, 1))) * 3.
    grad = torch.autograd.grad(loss_val, net.parameters(), create_graph=True, retain_graph=True)
    temp_net = update_module(temp_net, updates=tuple(-lr * g for g in grad))
    print(grad)
    print(list(temp_net.parameters())[0])

    # second inner-loop step
    loss_val = nn.functional.mse_loss(temp_net(data), torch.zeros((batch_size, 1))) * 3.
    grad = torch.autograd.grad(loss_val, temp_net.parameters(), create_graph=True, retain_graph=True)
    temp_net = update_module(temp_net, updates=tuple(-lr * g for g in grad))
    print(grad)
    print(list(temp_net.parameters())[0])

    # outer loop: meta-gradient w.r.t. the original parameters
    loss_val = nn.functional.mse_loss(temp_net(data), torch.zeros((batch_size, 1))) * 3.
    grad = torch.autograd.grad(loss_val, net.parameters())
    print(grad)

    # store the meta-gradient of each task
    meta_grads.append(grad)

# take the average of meta_grads and apply update_module on net
```
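The averaging step the final comment alludes to could look like the following sketch. It uses plain tensor ops rather than `update_module`, and the `meta_grads` list and `meta_lr` here are stand-ins I made up for illustration, not values from the run above:

```python
import torch
import torch.nn as nn

# Same manually initialized net as above.
net = nn.Sequential(nn.Linear(2, 1, bias=False))
with torch.no_grad():
    net[0].weight = nn.Parameter(torch.tensor([[.1, .2]]))

meta_lr = 1e-3  # hypothetical outer-loop learning rate

# Stand-in meta-gradients for two tasks; in the real loop these come
# from torch.autograd.grad in the outer-loop step.
meta_grads = [
    (torch.tensor([[1.0, 2.0]]),),
    (torch.tensor([[3.0, 4.0]]),),
]

# Average across tasks, parameter-wise: zip groups the i-th gradient
# of every task together, stack+mean averages them.
avg_grads = [torch.stack(gs).mean(dim=0) for gs in zip(*meta_grads)]

# Apply a plain SGD step to the meta-parameters.
with torch.no_grad():
    for p, g in zip(net.parameters(), avg_grads):
        p -= meta_lr * g
```

With the stand-in gradients above, the averaged gradient is `[[2., 3.]]`, so the weight moves from `[[.1, .2]]` to `[[.098, .197]]`.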

I haven't verified the computation yet. I'll get back to this once I have.

Edit:

- I changed the code a little for correctness, and used a pre-determined weight and inputs so I can compare against a manually computed result.
- I moved the `print` calls and the `lr*grad` computation so the code is more readable.
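For the manual comparison, the same two inner steps plus the meta-gradient can be reproduced with plain autograd, independent of learn2learn. This is my own sanity-check sketch using the same fixed weight and data, not code from the original post:

```python
import torch

lr = 1e-3
# Meta-parameter, same value as the manually set weight above.
w = torch.tensor([[.1, .2]], requires_grad=True)
x = torch.tensor([[1., 2.], [2., 1.], [2., 2.]])
target = torch.zeros(3, 1)

def loss_fn(weight):
    # mse_loss * 3, as in the snippet above (scales the mean by the batch size).
    return torch.nn.functional.mse_loss(x @ weight.t(), target) * 3.

# Two differentiable inner steps (create_graph keeps second-order terms).
w1 = w - lr * torch.autograd.grad(loss_fn(w), w, create_graph=True)[0]
w2 = w1 - lr * torch.autograd.grad(loss_fn(w1), w1, create_graph=True)[0]

# Meta-gradient of the adapted loss w.r.t. the original meta-parameter.
meta_grad = torch.autograd.grad(loss_fn(w2), w)[0]
print(meta_grad)
```

By hand, the first gradient is `2 * x.T @ (x @ w.T) = [[5.0, 5.2]]`, so `w1` should be `[[.095, .1948]]`, which matches what the learn2learn version prints after the first inner step.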