Gradient computation in meta-learning algorithms



I’m trying to implement meta-learning algorithms such as MAML (TF implementation) and Meta-SGD in PyTorch, and I’m having difficulty understanding the gradient flow.

train_samples, test_samples = task['train'], task['test']

images, labels = train_samples['image'], train_samples['label']
if args.cuda:
    images, labels = images.cuda(), labels.cuda()
images, labels = Variable(images), Variable(labels)

images, labels = test_samples['image'], test_samples['label']
if args.cuda:
    images, labels = images.cuda(), labels.cuda()
images, labels = Variable(images), Variable(labels)

# Inner gradient update

alpha = [Variable(, requires_grad=True) for p in base_learner.parameters()]

meta_optimizer = optim.Adam(base_learner.parameters(), lr=args.update_lr)

output = base_learner(images)
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(output, labels)

for param in base_learner.parameters():
    param.grad.requires_grad = True

base_learner_ = BaseLearner(args.num_classes, args.num_filters)
if args.cuda:
    base_learner_.cuda()

for param_, lr, param in zip(base_learner_.parameters(), alpha, base_learner.parameters()):
    param_ = param - lr*param.grad


# Note: parameters of base_learner_ don't seem to get updated here.

# Meta update


output = base_learner_(images)
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(output, labels)

  1. When I set the parameters of base_learner_ from base_learner, the assignment doesn’t seem to take effect. How do I do this properly?
  2. The gradient backpropagated from base_learner_ doesn’t flow to the parameters of base_learner. I know this is because the parameters of base_learner_ are not set properly from base_learner. I want to set them so that gradients from base_learner_ flow to the parameters of base_learner.

The algorithm I’m trying to implement is given at the beginning of page 5 of Meta-SGD: Learning to Learn Quickly for Few Shot Learning. Is this the right approach? What is the pragmatic way of doing it?

More code:

# Model

class ConvModule(nn.Module):
    """ Conv Module """
    def __init__(self, in_channels, out_channels):
        super(ConvModule, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = F.relu(x)
        return F.max_pool2d(x, 2)

class BaseLearner(nn.Module):
    """ Simple Conv Net """
    def __init__(self, num_classes, num_filters=64):
        super(BaseLearner, self).__init__()
        self.conv1 = ConvModule(1, num_filters)
        self.conv2 = ConvModule(num_filters, num_filters)
        self.conv3 = ConvModule(num_filters, num_filters)
        self.conv4 = ConvModule(num_filters, num_filters)
        self.fc = nn.Linear(num_filters, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = x.view(-1, 64)
        return self.fc(x)

I’m not sure why you can’t manually set the parameters of base_learner_.
But regardless, for MAML one problem you have is that in the meta step, you need to compute the loss using the updated weights, but then you need to compute the gradients with respect to the original weights.
To achieve this I just did away with modules, and directly used the functions that allow you to pass in different weight variables as arguments.
I also found this function useful for collecting gradients for the meta step:
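
For illustration, here is a minimal sketch of the functional approach described above, written as a Meta-SGD-style step. It assumes `torch.autograd.grad` with `create_graph=True` for the inner gradient (the post does not say which function was meant), and the toy linear model, the random data, and the names `theta`, `alpha`, `fast` and `forward` are all hypothetical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy task: 5-way classification on 8-dimensional inputs.
x_train, y_train = torch.randn(10, 8), torch.randint(0, 5, (10,))
x_test, y_test = torch.randn(10, 8), torch.randint(0, 5, (10,))

# Meta-parameters: original weights theta and per-parameter learning rates alpha.
theta = [(0.1 * torch.randn(5, 8)).requires_grad_(), torch.zeros(5, requires_grad=True)]
alpha = [torch.full_like(p, 0.01).requires_grad_() for p in theta]
meta_optimizer = torch.optim.Adam(theta + alpha, lr=1e-3)

def forward(x, params):
    """Functional forward pass: the weights are passed in as arguments."""
    weight, bias = params
    return F.linear(x, weight, bias)

# Inner update. create_graph=True keeps the gradient itself differentiable,
# so the updated weights remain a function of theta and alpha.
train_loss = F.cross_entropy(forward(x_train, theta), y_train)
grads = torch.autograd.grad(train_loss, theta, create_graph=True)
fast = [p - a * g for p, a, g in zip(theta, alpha, grads)]

# Meta update: loss computed with the updated weights, gradients taken
# with respect to the original weights (and alpha).
meta_loss = F.cross_entropy(forward(x_test, fast), y_test)
meta_optimizer.zero_grad()
meta_loss.backward()
meta_optimizer.step()
```

Because `fast` is built with ordinary tensor arithmetic instead of being assigned into a second module, `meta_loss.backward()` fills `.grad` on both `theta` and `alpha`.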


Yes, I can replace the Conv2d module with the functional form, but then how did you deal with the affine parameters in batch norm?
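
One possible way to handle this, as a sketch only: `F.batch_norm` accepts the affine scale and shift as `weight` and `bias` arguments, so they can be kept in the same parameter collection as the conv weights. The dictionary layout, the key names, and the choice to use batch statistics only (no running averages) are assumptions made for this example, not something stated in the thread.

```python
import torch
import torch.nn.functional as F

def conv_module(x, params, prefix):
    """Functional version of ConvModule; every weight comes from `params`."""
    x = F.conv2d(x, params[f"{prefix}.conv.weight"], params[f"{prefix}.conv.bias"], padding=1)
    # running_mean/running_var are None and training=True, so batch statistics
    # are used. The affine scale and shift are ordinary tensors passed in.
    x = F.batch_norm(x, None, None,
                     weight=params[f"{prefix}.bn.weight"],
                     bias=params[f"{prefix}.bn.bias"],
                     training=True)
    return F.max_pool2d(F.relu(x), 2)

torch.manual_seed(0)
params = {
    "conv1.conv.weight": (0.1 * torch.randn(4, 1, 3, 3)).requires_grad_(),
    "conv1.conv.bias": torch.zeros(4, requires_grad=True),
    "conv1.bn.weight": torch.ones(4, requires_grad=True),
    "conv1.bn.bias": torch.zeros(4, requires_grad=True),
}
x = torch.randn(6, 1, 8, 8)  # hypothetical batch of 8x8 single-channel images

# Inner step over all parameters, batch norm affine terms included.
# The squared-activation loss is a placeholder for a real task loss.
loss = conv_module(x, params, "conv1").pow(2).mean()
grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
fast = {name: p - 0.1 * g for (name, p), g in zip(params.items(), grads)}

# The meta loss uses the adapted weights; backward reaches the originals.
meta_loss = conv_module(x, fast, "conv1").pow(2).mean()
meta_loss.backward()
```

Whether to also track running statistics across tasks is a separate design choice that this sketch leaves out.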

Hello, have you implemented Meta-SGD successfully yet? We could compare notes, since I’m also working on Meta-SGD.

Yep, the functional form is useful for a conv-like net, but what should I do when I want to apply a meta update in a recurrent network?

@aitutakiv, I understand how doing away with modules makes it easier, but writing a functional forward for deeper networks is a pain. Is there a way to code MAML or FOMAML using modules?

Please let me know; any help is appreciated!
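
One option that keeps the modules, sketched here under the assumption that PyTorch 2.0 or later is available: `torch.func.functional_call` runs an existing `nn.Module` with a substitute dictionary of parameters, so no hand-written functional forward is needed. The small `nn.Sequential` model, the random data and the helper `adapt` are hypothetical, and this has not been checked against recurrent layers.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 5))
meta_optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def adapt(model, x, y, inner_lr=0.4, steps=1):
    """Return adapted weights as a dict, leaving the module's own weights untouched."""
    params = dict(model.named_parameters())
    for _ in range(steps):
        loss = F.cross_entropy(functional_call(model, params, (x,)), y)
        # create_graph=True keeps the update differentiable w.r.t. the original weights.
        grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
        params = {name: p - inner_lr * g for (name, p), g in zip(params.items(), grads)}
    return params

# Hypothetical task data: adaptation set and evaluation set.
x_train, y_train = torch.randn(10, 8), torch.randint(0, 5, (10,))
x_test, y_test = torch.randn(10, 8), torch.randint(0, 5, (10,))

fast = adapt(model, x_train, y_train)
meta_loss = F.cross_entropy(functional_call(model, fast, (x_test,)), y_test)

meta_optimizer.zero_grad()
meta_loss.backward()   # gradients land on model.parameters()
meta_optimizer.step()
```

The module only supplies the architecture here; the adapted weights live in the `fast` dictionary, which is what lets the meta gradient reach the original parameters.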


Hi, have you implemented MAML the “module” way? I hope you can post a link, thanks.

Hello @aitutakiv, I am a bit confused. How should I compute the gradients with respect to the original weights? If possible, could you review my code? It seems I am approaching the problem quite naively and missing something important. It would be great if you could give me a few pointers :)

My inner loop code:

def inner_loop(learner, batch, loss_fn, steps, device):
    X, y = batch # Batch Size = 2*shots*ways
    X, y = X.to(device), y.to(device)

    # Separate data into adaptation/evaluation sets
    # Dtrain and Dtest both of size ways*shots
    adapt_idx = np.zeros(X.size(0), dtype=bool)
    adapt_idx[np.arange(1*5) * 2] = True # shots*ways
    eval_idx = torch.from_numpy(~adapt_idx)
    adapt_idx = torch.from_numpy(adapt_idx)
    adapt_data, adapt_labels = X[adapt_idx], y[adapt_idx]
    eval_data, eval_labels = X[eval_idx], y[eval_idx]

    # Adaptation Steps (Steps: 0 to K-1)
    for step in range(steps): # steps: number of adaptation steps
        pred = learner(adapt_data)
        loss = loss_fn(pred, adapt_labels)
        grad = torch.autograd.grad(loss, learner.parameters())
        new_weight = list(map(lambda p: p[1] - 0.4*p[0], zip(grad, learner.parameters()))) # theta' = theta - alpha*grads
        with torch.no_grad(): # Updating the new_weights to learner
            for i, params in enumerate(learner.parameters()):
                params.copy_(new_weight[i])
    # Evaluation loss, accuracy
    pred_val = learner(eval_data)
    loss_val = loss_fn(pred_val, eval_labels)
    pred_val = pred_val.argmax(dim=1).view(eval_labels.shape)
    acc_val = (pred_val == eval_labels).sum().float()/eval_labels.size(0)
    return loss_val, acc_val

And the outer loop:

iterations = 5
meta_batch_size = 32
for iteration in range(iterations):
    train_loss, train_acc = 0, 0
    for task in range(meta_batch_size):
        # Copy model parameters. 
        learner = deepcopy(new_model)
        batch = tasksets.train.sample()  # Generates 2*shots*ways samples.
        # The batch will be split into Dtrain and Dtest in the inner loop (both of size shots*ways)
        loss_val, acc_val = inner_loop(learner, batch, loss_fn, 1, device) # Gradient steps: 1
        train_loss += loss_val/meta_batch_size
        train_acc += acc_val/meta_batch_size
        del learner
    if iteration % 1 == 0:
        print('Training Accuracy:', train_acc.item())
        print('Training Loss:', train_loss.item())

As you might have guessed, the computation graph is broken in the outer loop, so train_loss does not update any weights. It would be awesome if you could point out a few things. Thank you!

For additional information, here are the training loss and accuracy:

Training Accuracy: 0.3437499403953552
Training Loss: 1.4747799634933472
Training Accuracy: 0.3062499761581421
Training Loss: 1.510756254196167
Training Accuracy: 0.38749998807907104
Training Loss: 1.5028401613235474
Training Accuracy: 0.3312499225139618
Training Loss: 1.4998550415039062
Training Accuracy: 0.35625001788139343
Training Loss: 1.477844476699829
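
A note on the code above, with a hedged sketch: `deepcopy` creates new leaf parameters and the `torch.no_grad()` copy is not recorded by autograd, so `loss_val` has no path back to `new_model`, and no optimizer step is taken in the outer loop. If the copy-based structure is kept, one workaround is a first-order approximation (FOMAML-style, not full MAML): take the evaluation gradients with respect to the adapted copy and accumulate them on the original model. The linear model, the `sample_task` helper and the hyperparameters below are hypothetical stand-ins.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
new_model = nn.Linear(8, 5)   # stand-in for the real network
meta_optimizer = torch.optim.Adam(new_model.parameters(), lr=1e-3)
meta_batch_size = 4

def sample_task(n=10):
    """Hypothetical stand-in for tasksets.train.sample(), already split in two."""
    x, y = torch.randn(2 * n, 8), torch.randint(0, 5, (2 * n,))
    return (x[:n], y[:n]), (x[n:], y[n:])

meta_optimizer.zero_grad()
for _ in range(meta_batch_size):
    learner = copy.deepcopy(new_model)
    (x_adapt, y_adapt), (x_eval, y_eval) = sample_task()

    # Inner step on the copy (no second-order graph is kept).
    adapt_loss = F.cross_entropy(learner(x_adapt), y_adapt)
    grads = torch.autograd.grad(adapt_loss, list(learner.parameters()))
    with torch.no_grad():
        for p, g in zip(learner.parameters(), grads):
            p.sub_(0.4 * g)

    # First-order approximation: gradients at the adapted weights are
    # used directly as the meta gradient for the original weights.
    eval_loss = F.cross_entropy(learner(x_eval), y_eval)
    eval_grads = torch.autograd.grad(eval_loss, list(learner.parameters()))
    for p, g in zip(new_model.parameters(), eval_grads):
        g = g / meta_batch_size
        p.grad = g if p.grad is None else p.grad + g

meta_optimizer.step()
```

Full second-order MAML would instead need the adapted weights to stay connected to `new_model`, which rules out both the deep copy and the `no_grad` update.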