Hello,
I am currently implementing the MAML algorithm and encountering issues with computing the outer loop gradients, which should be with respect to the initial model parameters. However, since the output is derived from the updated weights (computed in the inner loop), the original model parameters are not part of the computation graph. As a result, I am unable to calculate the correct gradients for the outer loop. Getting Error:
‘’’
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.
‘’’
Any guidance or suggestions on how to resolve this issue would be greatly appreciated. Thank you!
Attaching code for reference:
meta_grad = copy.deepcopy(meta_grad_init)
loss_pre = []
loss_after = []
loss_q_all = []
keep_weight = deepcopy(model.state_dict())
weight_name = list(keep_weight.keys())
weight_len = len(keep_weight)
fast_weights = OrderedDict()
for i in range(args.tasks_per_metaupdate): #batch #32, task 32, user 32
weight_for_local_update = {k: v.clone().detach().requires_grad_(True) for k, v in model.state_dict().items()}
loss_pre.append(F.mse_loss(functional_call(model, fast_weights, x_qry[i]), y_qry[i]).item())
##Inner Loop##
for k in range(args.num_grad_steps_inner):
logits = functional_call(model, weight_for_local_update, x_spt[i])
loss = F.mse_loss(logits, y_spt[i])
model.zero_grad()
grad = torch.autograd.grad(loss, weight_for_local_update.values())
for i_inner in range(weight_len):
if weight_name[i_inner] in local_update_target_weight_name:
fast_weights[weight_name[i_inner]] = list(weight_for_local_update.values())[i_inner] - args.lr_inner * grad[i_inner] # time()
else:
fast_weights[weight_name[i_inner]] = list(weight_for_local_update.values())[i_inner]
weight_for_local_update = fast_weights
##Outer loop##
logits_q = functional_call(model, fast_weights, x_qry[i])
loss_q = F.mse_loss(logits_q, y_qry[i])
task_grad_test = torch.autograd.grad(loss_q, model.parameters())
loss_after.append(loss_q.item())
for g in range(len(task_grad_test)):
meta_grad[g] += task_grad_test[g].detach()
meta_optimiser.zero_grad()
for c, param in enumerate(model.parameters()):
param.grad = meta_grad[c] / float(args.tasks_per_metaupdate)
param.grad.data.clamp_(-10, 10)
meta_optimiser.step()