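Many MAML implementations use the nn.functional interface, e.g. F.linear(x, w, b), so that the forward pass can run with the adapted (fast) weights. However, this is inconvenient for predefined nn.Module models, whose forward() reads the weights stored inside the module itself.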
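For comparison, the functional style looks roughly like the sketch below. It assumes a tiny two-layer MLP whose weights are kept as plain tensors. The sizes, the random support/query data and the `forward`/`inner_step` helpers are illustrative, not taken from any particular implementation. One inner step adapts the weights, and the query loss is back-propagated through that step into the original weights.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Weights of a tiny two-layer MLP kept as plain tensors (illustrative sizes)
params = [
    torch.randn(16, 4, requires_grad=True),  # w1: (out_features, in_features)
    torch.zeros(16, requires_grad=True),     # b1
    torch.randn(3, 16, requires_grad=True),  # w2
    torch.zeros(3, requires_grad=True),      # b2
]

def forward(x, p):
    # The forward pass takes its weights as an argument instead of reading a module
    h = F.relu(F.linear(x, p[0], p[1]))
    return F.linear(h, p[2], p[3])

def inner_step(params, x, y, update_lr=0.1):
    loss = F.cross_entropy(forward(x, params), y)
    # create_graph=True keeps the graph so the meta-gradient is second order
    grads = torch.autograd.grad(loss, params, create_graph=True)
    return [p - update_lr * g for p, g in zip(params, grads)]

x_support, y_support = torch.randn(8, 4), torch.randint(0, 3, (8,))
x_query, y_query = torch.randn(8, 4), torch.randint(0, 3, (8,))

fast = inner_step(params, x_support, y_support)
meta_loss = F.cross_entropy(forward(x_query, fast), y_query)
meta_loss.backward()  # gradients flow back through the inner update into params
```

Every layer has to be rewritten this way, which is exactly what becomes tedious for a large predefined model such as a ResNet.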
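So, suppose I initialize a model (ResNet, for example) and copy its parameters:

```python
import copy

model = Resnet()
# parameters() returns a generator, which deepcopy cannot handle, so copy a list
param_copy = copy.deepcopy(list(model.parameters()))
```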
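The copy step needs care because `model.parameters()` returns a generator, and generators cannot be deep-copied. A quick pure-Python check, using a hypothetical generator function `parameters()` as a stand-in for `model.parameters()`:

```python
import copy

def parameters():
    # Hypothetical stand-in for nn.Module.parameters(), which is also a generator
    yield [1.0, 2.0]
    yield [3.0]

try:
    copy.deepcopy(parameters())
    generator_copy_ok = True
except TypeError as err:
    generator_copy_ok = False
    print("deepcopy of a generator fails:", err)

# Materializing the generator into a list first gives an independent copy
param_copy = copy.deepcopy(list(parameters()))
param_copy[0][0] = 99.0
print(param_copy, list(parameters()))
```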
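Then, during meta-training, each update step looks like this:

```python
import torch
import torch.nn.functional as F

# load parameters: rebinding the loop variable (model_param = meta_param) would leave
# the model untouched, so put each tensor into the _parameters dict (a private
# attribute) of the submodule that owns it: a reference assignment, i.e. a shallow copy
names = [name for name, _ in model.named_parameters()]
for name, meta_param in zip(names, param_copy):
    module_name, _, attr = name.rpartition(".")
    model.get_submodule(module_name)._parameters[attr] = meta_param
out = model(x)
loss = F.cross_entropy(out, y)
# create_graph=True keeps the second-order terms of MAML; omit it for first-order MAML
grad = torch.autograd.grad(loss, param_copy, create_graph=True)
# use grad to update param_copy
param_copy = [p - update_lr * g for g, p in zip(grad, param_copy)]
```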
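The "load parameters" loop only affects the model if it writes into the storage the model actually reads from: rebinding the loop variable with `model_param = meta_param` changes nothing outside the loop. The pure-Python sketch below uses a hypothetical `TinyModel` class whose `params` dict plays the role of an `nn.Module`'s `_parameters` dict. It shows the difference between rebinding a name and assigning into the container.

```python
class TinyModel:
    """Hypothetical stand-in for an nn.Module with two parameters."""

    def __init__(self):
        # Plays the role of nn.Module._parameters: what forward() would read
        self.params = {"weight": [1.0], "bias": [0.0]}

    def parameters(self):
        return iter(self.params.values())


model = TinyModel()
param_copy = [[5.0], [6.0]]

# Rebinding the loop variable only changes the local name model_param
for meta_param, model_param in zip(param_copy, model.parameters()):
    model_param = meta_param
after_rebinding = dict(model.params)
print(after_rebinding)  # still the original values

# Assigning into the container the model reads from does take effect
for name, meta_param in zip(list(model.params), param_copy):
    model.params[name] = meta_param  # reference assignment, no data copy
print(model.params)
```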
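In this way, the nn.functional interface is replaced by a shallow copy of the parameter tensors into the model: the module's unmodified forward() simply runs on whichever tensors sit in its parameter slots.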
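PyTorch 2.0 and later also provide `torch.func.functional_call`, which runs a module's own `forward` with a supplied dict of tensors temporarily standing in for its registered parameters. This gives the same effect without editing private attributes. Below is a minimal sketch of one MAML meta-step using it, assuming PyTorch 2.0 or newer. The small `nn.Sequential` model (standing in for the ResNet), the random support/query data, the Adam meta-optimizer and the two inner steps are illustrative choices, not part of the approach above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call

torch.manual_seed(0)

# Small stand-in for the ResNet; any nn.Module works the same way
model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 3))
meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
update_lr = 0.1

x_support, y_support = torch.randn(8, 4), torch.randint(0, 3, (8,))
x_query, y_query = torch.randn(8, 4), torch.randint(0, 3, (8,))

# Start the inner loop from the module's own parameters (no copy needed)
fast = dict(model.named_parameters())
for _ in range(2):  # inner-loop steps
    out = functional_call(model, fast, (x_support,))
    loss = F.cross_entropy(out, y_support)
    grads = torch.autograd.grad(loss, list(fast.values()), create_graph=True)
    fast = {name: p - update_lr * g for (name, p), g in zip(fast.items(), grads)}

# Outer (meta) update: evaluate the adapted weights on the query set
meta_loss = F.cross_entropy(functional_call(model, fast, (x_query,)), y_query)
meta_optimizer.zero_grad()
meta_loss.backward()  # second-order gradients reach the module's real parameters
meta_optimizer.step()
```

Because the adapted weights live in a separate dict, the module's registered parameters remain the meta-parameters throughout, so an ordinary optimizer can update them.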