Gradient computation in meta-learning algorithms


Related question:

I’m trying to implement meta-learning algorithms such as MAML (following the TF implementation) and Meta-SGD in PyTorch, but I’m having difficulty understanding the gradient flow.

train_samples, test_samples = task['train'], task['test']

train_images, train_labels = train_samples['image'], train_samples['label']
if args.cuda:
    train_images, train_labels = train_images.cuda(), train_labels.cuda()
train_images, train_labels = Variable(train_images), Variable(train_labels)

test_images, test_labels = test_samples['image'], test_samples['label']
if args.cuda:
    test_images, test_labels = test_images.cuda(), test_labels.cuda()
test_images, test_labels = Variable(test_images), Variable(test_labels)

# Inner gradient update

# Per-parameter learning rates (Meta-SGD), initialised to a small constant
alpha = [Variable(p.data.clone().fill_(0.01), requires_grad=True)
         for p in base_learner.parameters()]

meta_optimizer = optim.Adam(base_learner.parameters(), lr=args.update_lr)

output = base_learner(train_images)
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(output, train_labels)

# Backprop the inner loss, retaining the graph so that .grad is itself
# differentiable for the meta update
loss.backward(create_graph=True)

base_learner_ = BaseLearner(args.num_classes, args.num_filters)
if args.cuda:
    base_learner_.cuda()

for param_, lr, param in zip(base_learner_.parameters(), alpha, base_learner.parameters()):
    param_ = param - lr * param.grad

# Note: parameters of base_learner_ don't seem to get updated here.

# Meta update

output = base_learner_(test_images)
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(output, test_labels)

  1. When I set the parameters of base_learner_ from base_learner, the assignment doesn’t seem to take effect. How do I do this properly?
  2. The gradient from backpropagating through base_learner_ doesn’t flow to the parameters of base_learner. I know this is because the parameters of base_learner_ are not set properly from base_learner. I want to set them so that gradients computed through base_learner_ flow back to the parameters of base_learner.

The algorithm I’m trying to implement can be found at the beginning of page 5 here - Meta-SGD: Learning to Learn Quickly for Few Shot Learning. Is this the right approach? What is the pragmatic way of doing it?
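For reference, here is a minimal sketch of the inner and meta update on a toy scalar model, showing how `torch.autograd.grad` with `create_graph=True` keeps the adapted weight connected to the original one so the meta-gradient flows back. The names (`w`, `alpha`, the toy data) are illustrative, not from the code above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Base-learner parameter and per-parameter learning rate (both meta-learned)
w = torch.randn(1, requires_grad=True)
alpha = torch.full((1,), 0.1, requires_grad=True)

x_train, y_train = torch.tensor([2.0]), torch.tensor([4.0])
x_test, y_test = torch.tensor([3.0]), torch.tensor([6.0])

# Inner update: create_graph=True keeps the graph alive
train_loss = F.mse_loss(w * x_train, y_train)
(grad_w,) = torch.autograd.grad(train_loss, w, create_graph=True)
w_ = w - alpha * grad_w   # non-leaf tensor, still connected to w and alpha

# Meta update: test loss through the adapted weight; gradients
# reach the ORIGINAL parameters, as MAML/Meta-SGD require
test_loss = F.mse_loss(w_ * x_test, y_test)
meta_grad_w, meta_grad_alpha = torch.autograd.grad(test_loss, [w, alpha])
```

The key point is that `w_` must be an ordinary tensor in the graph, not a fresh `nn.Parameter`; assigning into a second module's `.parameters()` (as in the loop above) only rebinds the loop variable and severs the graph.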

More code:

# Model

class ConvModule(nn.Module):
    """ Conv Module """
    def __init__(self, in_channels, out_channels):
        super(ConvModule, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = F.relu(x)
        return F.max_pool2d(x, 2)

class BaseLearner(nn.Module):
    """ Simple Conv Net """
    def __init__(self, num_classes, num_filters=64):
        super(BaseLearner, self).__init__()
        self.conv1 = ConvModule(1, num_filters)
        self.conv2 = ConvModule(num_filters, num_filters)
        self.conv3 = ConvModule(num_filters, num_filters)
        self.conv4 = ConvModule(num_filters, num_filters)
        self.fc = nn.Linear(num_filters, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = x.view(x.size(0), -1)  # flatten; assumes spatial dims pool to 1x1
        return self.fc(x)

I’m not sure why you can’t manually set the parameters of base_learner_.
But regardless, for MAML one problem you have is that in the meta step you need to compute the loss using the updated weights, but then compute the gradients with respect to the original weights.
To achieve this I just did away with modules and directly used the functional forms, which let you pass different weight variables in as arguments.
I also found this function useful for collecting gradients for the meta step:
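Assuming the function in question is `torch.autograd.grad` (my assumption), a sketch of one way to collect per-task gradients for the meta step: `torch.autograd.grad` returns gradients without accumulating into `.grad`, so they can be summed manually across tasks before a single optimizer step on the original (meta) parameters. The toy loss and names here are illustrative.

```python
import torch

theta = torch.randn(3, requires_grad=True)      # meta parameters
meta_opt = torch.optim.SGD([theta], lr=0.01)

meta_opt.zero_grad()
for _ in range(4):  # pretend these are 4 sampled tasks
    x = torch.randn(3)
    task_loss = ((theta - x) ** 2).sum()        # stand-in for the adapted-weights loss
    (g,) = torch.autograd.grad(task_loss, theta)
    # accumulate per-task meta-gradients by hand
    theta.grad = g if theta.grad is None else theta.grad + g
meta_opt.step()
```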

Yes, I can replace the Conv2d module with its functional form, but then how did you deal with the affine parameters in batch norm?
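If it helps: `F.batch_norm` accepts the affine parameters explicitly, so the adapted (fast) versions of gamma and beta can be passed in just like conv weights. A sketch along these lines, with illustrative names:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 8, 5, 5)
gamma = torch.ones(8, requires_grad=True)    # affine scale (meta parameter)
beta = torch.zeros(8, requires_grad=True)    # affine shift (meta parameter)

# In the inner loop these would be the updated fast weights,
# e.g. gamma_ = gamma - alpha * grad_gamma; used directly here for brevity.
out = F.batch_norm(
    x, running_mean=None, running_var=None,  # use batch statistics
    weight=gamma, bias=beta, training=True,
)
```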

Hello, have you implemented the Meta-SGD algorithm successfully yet? We could communicate, since I’m working on Meta-SGD as well.

Yep, the functional approach works well for a conv-like net, but what should I do when I want to apply a meta update in a recurrent network?
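The same functional trick applies to recurrent nets if you write the cell update with explicit weight tensors, which can then be replaced by fast weights in the inner loop. A sketch of a vanilla tanh RNN cell; the weight names are illustrative:

```python
import torch

def rnn_cell(x, h, w_ih, w_hh, b_ih, b_hh):
    # h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh)
    return torch.tanh(x @ w_ih.t() + b_ih + h @ w_hh.t() + b_hh)

input_size, hidden_size = 3, 5
w_ih = torch.randn(hidden_size, input_size, requires_grad=True)
w_hh = torch.randn(hidden_size, hidden_size, requires_grad=True)
b_ih = torch.zeros(hidden_size, requires_grad=True)
b_hh = torch.zeros(hidden_size, requires_grad=True)

# Unroll over time; in the inner loop you would pass the adapted
# fast weights instead of the originals
x_seq = torch.randn(7, 1, input_size)   # (time, batch, features)
h = torch.zeros(1, hidden_size)
for x_t in x_seq:
    h = rnn_cell(x_t, h, w_ih, w_hh, b_ih, b_hh)
```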

@aitutakiv, I understand how doing away with modules makes it easier, but writing a functional forward for deeper networks is a pain. Is there a way to code MAML or FOMAML using modules?

Please let me know; any help is appreciated!
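One option, in more recent PyTorch versions (this assumes PyTorch >= 2.0, well after this thread): `torch.func.functional_call` runs a module's forward pass with a substituted dict of (fast) weights, so you keep the `nn.Module` definition and skip the hand-written functional forward. A sketch on a toy linear model:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call

net = nn.Linear(4, 2)
params = dict(net.named_parameters())

x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))

# Inner update with create_graph=True so the meta-gradient flows
loss = F.cross_entropy(functional_call(net, params, (x,)), y)
grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
fast = {name: p - 0.01 * g for (name, p), g in zip(params.items(), grads)}

# Meta loss through the adapted weights; gradients reach the originals
meta_loss = F.cross_entropy(functional_call(net, fast, (x,)), y)
meta_grads = torch.autograd.grad(meta_loss, list(params.values()))
```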


Hi, have you implemented MAML in the “module” way? If so, hope you can post a link, thanks.