Another way to implement MAML (Model-Agnostic Meta-Learning)?

Typically, many MAML implementations use the nn.functional interface, such as F.linear(x, w, b), because the updated (fast) weights can be passed in explicitly. However, this is unfriendly to pre-defined nn.Module models.
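For context, here is a minimal sketch of that functional style; the single linear layer, shapes, and learning rate below are made up purely for illustration:

import torch
import torch.nn.functional as F

# toy setup: one linear layer kept as explicit tensors (shapes are placeholders)
w = torch.randn(10, 784, requires_grad=True)
b = torch.zeros(10, requires_grad=True)
x, y = torch.randn(32, 784), torch.randint(0, 10, (32,))
update_lr = 0.01

# inner-loop step: the weights are passed explicitly to F.linear,
# so the updated (fast) weights stay in the autograd graph
logits = F.linear(x, w, b)
loss = F.cross_entropy(logits, y)
grad_w, grad_b = torch.autograd.grad(loss, (w, b), create_graph=True)
fast_w, fast_b = w - update_lr * grad_w, b - update_lr * grad_b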
So, suppose I initialize a model (a ResNet, for example) and copy its parameters:

import copy

model = Resnet()
# deepcopy the parameters (materialize the generator first, since a generator cannot be deepcopied)
param_copy = copy.deepcopy(list(model.parameters()))

Then, I do the meta-training; each inner-loop update step is as follows:

# load parameters
for meta_param, model_param in zip(param_copy, model.parameters()):
    # shallow copy: the intent is that the forward pass uses meta_param
    model_param = meta_param

out = model(x)
loss = F.cross_entropy(out, y)
# create_graph=True so the meta (outer) gradient can flow through this update
grad = torch.autograd.grad(loss, param_copy, create_graph=True)
# use grad to update param_copy out of place
param_copy = [p - update_lr * g for g, p in zip(grad, param_copy)]

Thus, in this fashion, the calls to the nn.functional interface are replaced by shallow copies of the parameter variables.
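To make the intent concrete, here is a rough sketch of how the parameter variables could be rebound so that the forward pass actually reads param_copy and keeps it in the autograd graph. The load_params helper below is hypothetical (not a PyTorch API), and the delattr/setattr trick is fragile; for example, model.parameters() no longer reports the replaced tensors afterwards:

# capture the parameter names once, before any rebinding
param_names = [name for name, _ in model.named_parameters()]

def load_params(module, names, tensors):
    # hypothetical helper: attach each (possibly non-leaf) tensor onto the module
    # so that the next forward pass reads it instead of the original nn.Parameter
    for name, tensor in zip(names, tensors):
        *path, leaf = name.split(".")
        sub = module
        for part in path:
            sub = getattr(sub, part)
        delattr(sub, leaf)          # remove the registered nn.Parameter
        setattr(sub, leaf, tensor)  # attach the plain tensor in its place

load_params(model, param_names, param_copy)
out = model(x)  # the forward pass now uses param_copy, so gradients can flow back to it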

Hey, just to confirm my understanding of the issue:

  • you are creating a new set of parameters by updating them out of place with the gradients
  • nn.functional is easy to use because you can just pass in the new parameters
  • however, it is harder to replace the parameters of a pre-defined nn.Module

Please correct me if I’m misunderstanding!

  1. Yes.
  2. Yes.
  3. Yes.

These might help, depending on whether you want to use functorch or not:

FYI: these APIs do very similar things; we're in the process of moving functorch into PyTorch and merging the two APIs.
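For reference, a minimal sketch of the stateless route, assuming PyTorch >= 1.12 (newer releases expose the same function as torch.func.functional_call); model, x, y, and update_lr are placeholders from the snippets above:

import torch
import torch.nn.functional as F
from torch.nn.utils.stateless import functional_call

# copy the parameters to adapt; the names must match model.named_parameters()
params = {name: p.clone() for name, p in model.named_parameters()}

# inner-loop step: run the pre-defined module with the copied parameters
out = functional_call(model, params, (x,))
loss = F.cross_entropy(out, y)
grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
params = {name: p - update_lr * g
          for (name, p), g in zip(params.items(), grads)}

The outer loop would then call functional_call again with the adapted params on the query set and backpropagate through the whole thing into model.parameters().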

Thanks! This is what I need.