I am implementing a meta-learning model that has meta-train and meta-test steps. For the meta-train, I optimize the parameters by

```python
import torch
import torch.optim as optim
from collections import OrderedDict

optimizer = optim.SGD(Model.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)
grads_Model = torch.autograd.grad(meta_train_loss, Model.parameters(), ???#1)  # ???#1: create_graph / retain_graph?
fast_weights_Model = Model.state_dict()
adapted_params = OrderedDict()
for (key, val), grad in zip(Model.named_parameters(), grads_Model):
    # Inner-loop SGD step on each parameter
    adapted_params[key] = val - meta_step_size * grad
    fast_weights_Model[key] = adapted_params[key]
```
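For comparison, here is a minimal sketch of one meta-train/meta-test step in which the adapted weights are actually used in a forward pass, so the meta-test loss depends on them. It assumes PyTorch 2.0 or later (for `torch.func.functional_call`). The toy `nn.Linear` model, the random support/query batches and the `meta_step_size` value are hypothetical stand-ins, not the original model or data.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)

# Hypothetical toy setup: a small regression model and random support/query batches.
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)
loss_fn = nn.MSELoss()
meta_step_size = 0.01  # inner-loop learning rate (assumed value)

x_train, y_train = torch.randn(8, 4), torch.randn(8, 1)  # meta-train (support) batch
x_test, y_test = torch.randn(8, 4), torch.randn(8, 1)    # meta-test (query) batch

# Meta-train: inner loss computed with the initial parameters.
params = dict(model.named_parameters())
meta_train_loss = loss_fn(model(x_train), y_train)

# create_graph=True makes the returned gradients differentiable (second-order
# terms are kept) and, by default, also retains the graph.
grads = torch.autograd.grad(meta_train_loss, list(params.values()), create_graph=True)

# Adapted ("fast") weights are ordinary non-leaf tensors that stay in the graph.
adapted_params = {
    name: p - meta_step_size * g for (name, p), g in zip(params.items(), grads)
}

# Meta-test: run the same module, but with the adapted weights instead of its own.
meta_test_loss = loss_fn(functional_call(model, adapted_params, (x_test,)), y_test)

# Outer update w.r.t. the initial parameters; a single backward() needs no retain_graph.
total_loss = meta_train_loss + meta_test_loss
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
```

The main difference from the snippet above is that the fast weights are passed to the forward pass rather than being stored in a copy of the `state_dict`, which the meta-test forward pass never reads.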

For the meta-test phase, I calculate some other losses and add them together along with the meta-train loss to get a total loss and want to optimize it w.r.t the initial parameters:

My question is how I should use create_graph=True and retain_graph=True at the ???#1 and ???#2 positions to make this correct. If I pass create_graph=True at ???#1 and leave ???#2 blank, I get an error saying I should set retain_graph=True at ???#2. If I pass retain_graph=True at ???#1 and leave ???#2 blank, it works, and passing create_graph=True at ???#2 also works when retain_graph=True is set at ???#1.
I am confused and cannot figure out the correct form.
To me, we are already adapting the parameters in the meta-train step, and we just need to keep the graph so that we can backpropagate through it in the meta-test step, using the total loss, w.r.t. the initial parameters Model.parameters(). Am I correct, or am I missing something here?
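The difference between the two flags can be seen on a scalar. `retain_graph=True` only keeps the graph alive for another backward pass; the gradient it returns is a constant with no history. `create_graph=True` records the gradient computation itself, so the gradient can be differentiated again (per the `torch.autograd.grad` documentation, `retain_graph` defaults to the value of `create_graph`). In a MAML-style update this is what decides whether the second-order term through the inner step is kept. The example below is a standalone illustration, not code from the thread.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)

# retain_graph=True only: the graph of y survives, but the returned gradient
# has no history, so nothing can be differentiated through it.
y = x ** 3
(g_retain,) = torch.autograd.grad(y, x, retain_graph=True)
print(g_retain, g_retain.requires_grad)  # tensor(12.) False

# create_graph=True: dy/dx = 3x^2 is itself part of the graph, so a
# second derivative (6x) can be taken from it.
y = x ** 3
(g_create,) = torch.autograd.grad(y, x, create_graph=True)
(g2,) = torch.autograd.grad(g_create, x)
print(g_create.item(), g2.item())  # 12.0 12.0
```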

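I have a similar doubt. What I have found so far is that nn.Parameters are leaf nodes with no history. The inner-loop Gradient Descent (GD) step (adapted_params[key] = val - meta_step_size * grad) does produce new tensors, but writing them into the state_dict, or copying them back into the nn.Parameters, does not carry that history over, so the update never becomes part of the computation graph that the meta-test loss is built from.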
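A small check of this leaf/history behaviour, as a sketch with a hypothetical toy `nn.Linear` and random data: the adapted tensor itself has a `grad_fn`, but after copying it into the module with `load_state_dict` (which copies values without recording autograd history), the parameter is still a plain leaf.

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 1)
x, y = torch.randn(5, 3), torch.randn(5, 1)

loss = nn.functional.mse_loss(model(x), y)
grads = torch.autograd.grad(loss, list(model.parameters()), create_graph=True)

# The update itself is recorded: the result is a non-leaf tensor with a grad_fn.
new_weight = model.weight - 0.1 * grads[0]
print(new_weight.is_leaf, new_weight.grad_fn is not None)  # False True

# Copying it back into the module only copies the values:
# the nn.Parameter stays a leaf and the update history is lost.
state = model.state_dict()
state["weight"] = new_weight
model.load_state_dict(state)
print(model.weight.is_leaf, model.weight.grad_fn)  # True None
```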

The following may help you understand in more detail why updating model parameters the way you do won't work in PyTorch:

In fact, you may want to use the higher package, or refer to the following code for a workaround (look for MetaModule in the repositories):

If you use the neat hack of replacing PyTorch's nn.Module with MetaModule, as mentioned in my first response to your question, you will be able to write and run your meta-learning algorithms.
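The MetaModule idea can be sketched as a layer whose `forward` optionally accepts an external dictionary of weights, so the adapted tensors (with their history) are used directly instead of being copied into the parameters. `MetaLinear` below is a hypothetical illustration of that pattern, not the actual class from any of the repositories mentioned; it relies only on `torch.nn.functional.linear`, and the data and inner learning rate are made up.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MetaLinear(nn.Linear):
    """Hypothetical MetaModule-style layer: forward optionally takes external weights."""

    def forward(self, x, params=None):
        if params is None:
            # Ordinary call: use the module's own (leaf) parameters.
            return F.linear(x, self.weight, self.bias)
        # Meta call: use caller-supplied tensors, which may carry autograd history.
        return F.linear(x, params["weight"], params["bias"])


torch.manual_seed(0)
layer = MetaLinear(4, 1)
x_train, y_train = torch.randn(8, 4), torch.randn(8, 1)
x_test, y_test = torch.randn(8, 4), torch.randn(8, 1)

# Inner step: differentiable gradients, adapted weights kept as non-leaf tensors.
params = dict(layer.named_parameters())
inner_loss = F.mse_loss(layer(x_train), y_train)
grads = torch.autograd.grad(inner_loss, list(params.values()), create_graph=True)
fast = {name: p - 0.01 * g for (name, p), g in zip(params.items(), grads)}

# Outer step: gradients reach layer.weight / layer.bias through `fast`.
outer_loss = F.mse_loss(layer(x_test, params=fast), y_test)
outer_loss.backward()
```

Libraries built around this pattern apply the same `params=None` convention to every layer, so whole networks can be evaluated with fast weights.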

Hopefully, this helps you see how to work around the problem and implement meta-learning algorithms without having to use packages such as higher, pytorch-meta, etc.