I was going through the example (https://github.com/facebookresearch/higher/blob/master/examples/maml-omniglot.py) and it used:
with higher.innerloop_ctx( net, inner_opt, copy_initial_weights=False) as (fnet, diffopt):
Why does MAML require `copy_initial_weights=False` here?
Simplified code snippet:
def train(db, net, device, meta_opt, epoch, log):
    """Run one epoch of MAML meta-training over tasks sampled from `db`.

    For each meta-batch, an inner SGD optimizer adapts a differentiable
    copy of `net` on each task's support set; the resulting query losses
    are backpropagated through the unrolled inner steps to update `net`'s
    meta-parameters via `meta_opt`.

    Args:
        db: task sampler; provides `x_train`, `batchsz`, and `next()`
            returning (x_spt, y_spt, x_qry, y_qry) batches.
        net: the meta-learner model whose initial weights are trained.
        device: unused here — presumably tensors are already on device;
            TODO confirm against the caller.
        meta_opt: outer-loop optimizer over `net.parameters()`.
        epoch: current epoch index (unused in this simplified snippet).
        log: logging sink (unused in this simplified snippet).
    """
    net.train()
    # BUG FIX: the snippet had `db.x_train.shape // db.batchsz`, which
    # floor-divides a shape *tuple* by an int and raises TypeError.
    # The upstream maml-omniglot example indexes the first dimension.
    n_train_iter = db.x_train.shape[0] // db.batchsz

    for batch_idx in range(n_train_iter):
        start_time = time.time()

        # Sample a batch of support and query images and labels.
        x_spt, y_spt, x_qry, y_qry = db.next()
        task_num, setsz, c_, h, w = x_spt.size()

        # Initialize the inner optimizer used to adapt the parameters
        # to each task's support set.
        n_inner_iter = 5
        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

        meta_opt.zero_grad()
        for i in range(task_num):
            with higher.innerloop_ctx(
                net, inner_opt, copy_initial_weights=False
            ) as (fnet, diffopt):
                # Optimize the likelihood of the support set by taking
                # gradient steps w.r.t. the model's parameters.
                # higher keeps differentiable copies of the parameters
                # as they are updated.
                for _ in range(n_inner_iter):
                    spt_logits = fnet(x_spt[i])
                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                    diffopt.step(spt_loss)

                # The adapted parameters induce a loss on the query set;
                # this is the signal for the meta-update.
                qry_logits = fnet(x_qry[i])
                qry_loss = F.cross_entropy(qry_logits, y_qry[i])

                # Backprop unrolls through the inner gradient steps.
                # With copy_initial_weights=False the initial weights are
                # net's own parameters (not detached copies), so gradients
                # accumulate directly on net for the meta-optimizer.
                qry_loss.backward()

        # One outer step per meta-batch, after accumulating all task grads.
        meta_opt.step()