How can you do memory-efficient MAML?

Hi, I’m trying to do MAML with learn2learn, but it occurred to me that the loss for the meta-model is extremely expensive in terms of memory, because it includes the computational graph of each base learner for every task in a mini-batch, which is completely unnecessary.

What I’m trying to do is just accumulate the weight updates from each base learner (+ task pair) without keeping the entire computational graph. How can you do this without actually applying the weight updates?

For example, is it possible/easy to store a copy of the model where you zero all parameters, then apply the weight updates from the real model to that model one task at a time, and once you are done, add that model’s parameters back into the original model?

I’m particularly interested in whether anyone knows of a better solution.

I think we stumbled upon a very similar issue with our MAML implementation. I haven’t tested it yet, but I think we can indeed do the updates after each task without storing the computational graphs of all tasks.
$$
\begin{aligned}
\theta_1 &\leftarrow \theta - \alpha \frac{\partial f(\theta)}{\partial \theta} \\
\theta_2 &\leftarrow \theta_1 - \alpha \frac{\partial f(\theta_1)}{\partial \theta_1} \\
\theta &\leftarrow \theta - \alpha \frac{\partial f(\theta_2)}{\partial \theta}
\end{aligned}
$$
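Applying the chain rule to the third line shows where MAML's second-order terms come from. Here H denotes the Hessian of f, gradients are written as row vectors, and H is symmetric; for the quadratic toy loss used in the code of this thread H is constant, so this product is what `grad_out_analytic` evaluates:

$$
\frac{\partial f(\theta_2)}{\partial \theta}
= \nabla f(\theta_2)\,\frac{\partial \theta_2}{\partial \theta_1}\,\frac{\partial \theta_1}{\partial \theta}
= \nabla f(\theta_2)\,\bigl(I - \alpha H(\theta_1)\bigr)\bigl(I - \alpha H(\theta)\bigr)
$$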

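```python
import torch
import torch.nn as nn

from learn2learn.utils import clone_module, update_module

batch_size = 3
lr = 1e-3

# set up weight manually
net = nn.Sequential(nn.Linear(2, 1, bias=False))
net[0].weight = nn.Parameter(torch.tensor([[.1, .2]]))

# dataset = [torch.randn(batch_size, 2) for _ in range(task_size)]
dataset = [torch.tensor([[1., 2.], [2., 1.], [2., 2.]])]

meta_grads = []  # one tuple of meta-gradients (one entry per parameter) per task
for data in dataset:
    # differentiable copy: its parameters are clones that stay connected to net's
    temp_net = clone_module(net)

    # first inner-loop step: theta_1 = theta - lr * df(theta)/dtheta
    loss_val = nn.functional.mse_loss(temp_net(data), torch.zeros((batch_size, 1))) * 3.
    grads = torch.autograd.grad(loss_val, temp_net.parameters(), create_graph=True)
    update_module(temp_net, updates=[-lr * g for g in grads])

    print(list(temp_net.parameters())[0])

    # second inner-loop step: theta_2 = theta_1 - lr * df(theta_1)/dtheta_1
    loss_val = nn.functional.mse_loss(temp_net(data), torch.zeros((batch_size, 1))) * 3.
    grads = torch.autograd.grad(loss_val, temp_net.parameters(), create_graph=True)
    update_module(temp_net, updates=[-lr * g for g in grads])

    print(list(temp_net.parameters())[0])

    # outer loop: df(theta_2)/dtheta w.r.t. the ORIGINAL net's parameters;
    # the graph of this task is freed once its meta-gradient is stored
    loss_val = nn.functional.mse_loss(temp_net(data), torch.zeros((batch_size, 1))) * 3.
    meta_grads.append(torch.autograd.grad(loss_val, net.parameters()))

# take average of meta_grads and apply it to net in place
avg_grads = [torch.stack(gs).mean(dim=0) for gs in zip(*meta_grads)]
with torch.no_grad():
    for p, g in zip(net.parameters(), avg_grads):
        p -= lr * g
```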

I haven’t checked the computation yet. I’ll get back to this once I have.

Edit:

• I changed the code a little for correctness and to use predetermined weights and inputs, so I can compare against a manually computed result.
• I moved the print calls and the lr*grad updates so the code is more readable.

I checked, and the code is indeed correct: I compared its output with the following analytic computation.

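```python
import torch
import torch.nn as nn

# The loss is mse_loss(X @ w.T, 0) * 3 = sum_i (x_i . w)^2 for the 3 inputs X used above,
# so df/dw = w @ H with the (constant) hessian H = 2 * X^T X
hessian = torch.tensor([[18., 16.], [16., 18.]])

# set up weight manually
net = nn.Sequential(nn.Linear(2, 1, bias=False))
net[0].weight = nn.Parameter(torch.tensor([[.1, .2]]))

lr_ = 1e-3

## GRADIENT COMPUTATION USING CLOSED FORM SOLUTION
with torch.no_grad():
    grad_in_1_analytic = net[0].weight @ hessian
    fast_weight_in_1_analytic = net[0].weight - lr_ * grad_in_1_analytic

    print("Parameter:", fast_weight_in_1_analytic)
    print()

    grad_in_2_analytic = fast_weight_in_1_analytic @ hessian
    fast_weight_in_2_analytic = fast_weight_in_1_analytic - lr_ * grad_in_2_analytic

    print("Parameter:", fast_weight_in_2_analytic)
    print()

    # chain rule: each inner step contributes a Jacobian factor (I - lr * H)
    grad_in_3_analytic = fast_weight_in_2_analytic @ hessian
    grad_out_analytic = grad_in_3_analytic @ (torch.eye(2) - lr_ * hessian) @ (torch.eye(2) - lr_ * hessian)

    print("Meta-gradient:", grad_out_analytic)
```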



To answer your main question: this solution does not store the computation graph for all tasks. Instead, it constructs the computation graph for one task, computes the meta-gradient, stores the meta-gradient in a list that can later be used (i.e. averaged) to meta-update the main network, and then repeats for the next task.
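If you only need the averaged meta-gradient, the list is not strictly necessary: calling backward() once per task adds that task's meta-gradient into `.grad` and frees its graph before the next task starts, which is close to the "accumulate without the whole graph" idea from the question. This is a minimal plain-PyTorch sketch (no learn2learn), assuming the same bias-free linear model and quadratic toy loss as the thread, written with an explicit weight vector; `adapted_weights` is a hypothetical helper name and the random tasks are made up for illustration.

```python
import torch


def adapted_weights(w, data, n_steps, lr):
    """Differentiable inner loop on the toy loss f(w) = sum((data @ w) ** 2)."""
    fast_w = w
    for _ in range(n_steps):
        loss = ((data @ fast_w) ** 2).sum()
        # create_graph=True keeps second-order terms so the outer gradient flows through
        (g,) = torch.autograd.grad(loss, fast_w, create_graph=True)
        fast_w = fast_w - lr * g
    return fast_w


torch.manual_seed(0)
w = torch.nn.Parameter(torch.tensor([.1, .2]))
meta_optimizer = torch.optim.SGD([w], lr=1e-3)
tasks = [torch.randn(3, 2) for _ in range(4)]  # hypothetical meta-batch of 4 tasks

meta_optimizer.zero_grad()
for data in tasks:
    fast_w = adapted_weights(w, data, n_steps=2, lr=1e-3)
    # divide by the number of tasks so w.grad ends up holding the mean meta-gradient
    outer_loss = ((data @ fast_w) ** 2).sum() / len(tasks)
    # backward() adds this task's meta-gradient into w.grad and then frees this
    # task's graph, so only one task's graph is alive at any time
    outer_loss.backward()

accumulated = w.grad.clone()  # averaged meta-gradient over the meta-batch
meta_optimizer.step()
```

The result is the same gradient you would get from summing all task losses first and calling backward() once; the difference is only in peak memory.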


Hi, yes, I think we did.
Could you elaborate a little on your equations and the solution?
In particular, the second equation is not something I would expect to see in the original MAML, unless it is part of the inner loop. Also, do you know whether your code will work with a data-parallel training setup (across GPUs)?

Thanks, I really appreciate your help with this.

Hi, those equations are an example of how to do MAML with two steps of gradient descent per task, where f is the loss function. The first two lines are the two gradient-descent steps on the task objective, and the third line is the gradient-descent step on the meta objective.

Compared with Algorithm 1 in the MAML paper:

• The first two lines correspond to Line 6 of Algorithm 1. There are two of them because we take two inner gradient steps, which shows how you would implement the general “n steps of inner gradient descent” version of MAML.
• The third line corresponds to Line 8 of Algorithm 1. However, in the code I stop after computing df(theta_2)/dtheta and store that value in a list, which ends up holding df(theta_2)/dtheta for every task. After exiting the `for data in dataset` loop, I average that list and do the meta-update as in Line 8. The list of df(theta_2)/dtheta plays the role of the copy of the model you were asking about:

> For example is it possible/easy to store a copy of the model where you zero all parameters then apply the weight updates from the real model to that model one task at a time.
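The same per-task computation, generalized to n inner steps, can be sketched without learn2learn by treating the weights functionally. This is a minimal sketch assuming the toy setup from the thread (bias-free linear model, loss f(w) = sum((X @ w)**2), which equals mse_loss * 3 for 3 inputs); `inner_loop` and `task_meta_grad` are illustrative names, not learn2learn functions. With n_steps=2 it should reproduce `grad_out_analytic`.

```python
import torch


def inner_loop(w, data, n_steps, lr):
    """n steps of differentiable gradient descent on f(w) = sum((data @ w) ** 2)."""
    fast_w = w
    for _ in range(n_steps):
        loss = ((data @ fast_w) ** 2).sum()
        # create_graph=True keeps the Hessian terms needed for d theta_n / d theta
        (g,) = torch.autograd.grad(loss, fast_w, create_graph=True)
        fast_w = fast_w - lr * g
    return fast_w


def task_meta_grad(w, data, n_steps, lr):
    """Meta-gradient df(theta_n)/dtheta of one task; its graph is freed on return."""
    fast_w = inner_loop(w, data, n_steps, lr)
    outer_loss = ((data @ fast_w) ** 2).sum()
    (meta_g,) = torch.autograd.grad(outer_loss, w)  # retain_graph=False frees the graph
    return meta_g


lr = 1e-3
w = torch.tensor([.1, .2], requires_grad=True)  # same initial weight as in the thread
tasks = [torch.tensor([[1., 2.], [2., 1.], [2., 2.]])]  # append more tasks as needed

# store one meta-gradient per task, then average them and take one meta-step
meta_grads = [task_meta_grad(w, data, n_steps=2, lr=lr) for data in tasks]
with torch.no_grad():
    w -= lr * torch.stack(meta_grads).mean(dim=0)
```

Only the small meta-gradient tensors survive between tasks, so memory grows with the number of parameters rather than with the number of tasks times the graph size.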

Feel free to ask/discuss further if you still have doubts as my answer might not fully resolve your question.

Regarding parallelization across multiple GPUs, I’m not sure how to implement it. I think it can be done efficiently by computing the meta-gradients of different tasks on separate GPUs. The gradients computed on each GPU are then sent to the main GPU, averaged there, and used to update the model; the main GPU then updates the model copy on each of the other GPUs. Basically, it’s a synchronous update. (Sorry for the bad explanation, I have little experience with this, so I don’t know much of the terminology.)

Ha, I found this article! We can do data parallelism with a centralized synchronous update.
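One way to get the synchronous update described above with torch.distributed is to all-reduce the gradients after each process has computed the meta-gradients of its own tasks. This is a minimal sketch, not the method from the article: `average_gradients` is a hypothetical helper, and the demo runs a single CPU process with the gloo backend so it can execute anywhere, whereas a real multi-GPU run would start one process per GPU (for example with torchrun). Note that all_reduce leaves the averaged gradient on every process instead of collecting it on a main GPU, which gives the same synchronous update without a separate broadcast step.

```python
import os
import tempfile

import torch
import torch.distributed as dist


def average_gradients(model: torch.nn.Module) -> None:
    """Replace every p.grad by its mean over all processes (synchronous update).

    Assumes each process has already stored the meta-gradient of its own tasks
    in p.grad, e.g. via backward() or by assigning torch.autograd.grad results.
    """
    world_size = dist.get_world_size()
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)  # in-place sum over processes
            p.grad /= world_size


# Single-process CPU demo with the gloo backend; a real run would launch one
# process per GPU and typically use the nccl backend.
init_file = os.path.join(tempfile.mkdtemp(), "dist_init")
dist.init_process_group("gloo", init_method=f"file://{init_file}", rank=0, world_size=1)

model = torch.nn.Linear(2, 1, bias=False)
model.weight.grad = torch.tensor([[1.0, 2.0]])  # stand-in for this process's meta-gradient
average_gradients(model)

# every process now holds the same averaged gradient, so identical optimizer
# steps keep all model replicas in sync
torch.optim.SGD(model.parameters(), lr=1e-3).step()
dist.destroy_process_group()
```

Because the helper only touches `p.grad`, it does not matter whether the gradients came from backward() or were assigned from torch.autograd.grad.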

Ah, I see, so it was part of the inner loop. That makes sense.

Unfortunately, I can’t read that article because I don’t subscribe to Medium.
I’m currently using nn.DataParallel(). I was just asking whether it works when you aren’t calling backward(). I guess maybe you’d use torch.distributed.autograd for that?

P.S. Do you know how many subtasks you need for meta-learning to be viable?

Thanks for the update and the quick reply; this works for me. I was looking into the same concern.