I am trying to implement this paper for a generic network. I found an implementation that works, but it requires every network to be redefined on top of a custom base class (MetaModule), which does not scale well.
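For comparison, one generic alternative to a MetaModule-style wrapper (an assumption on my part, not what the implementation mentioned above does) is `torch.func.functional_call` (PyTorch 2.0+). It runs an unmodified `nn.Module` with a dict of replacement tensors, so gradients flow to whatever tensors you pass in. The toy network and the `shift` tensor below are made up purely for illustration.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)
# Any ordinary nn.Module works; no MetaModule subclass is required.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
x = torch.randn(3, 4)

# Parameters as a plain dict of tensors, keyed by the module's own names.
params = dict(net.named_parameters())

# With the module's own parameters, functional_call matches a normal forward pass.
same = torch.allclose(functional_call(net, params, (x,)), net(x))

# With substituted tensors (every parameter shifted by a scalar that requires grad),
# the output becomes a differentiable function of that scalar.
shift = torch.zeros((), requires_grad=True)
shifted = {name: p + shift for name, p in params.items()}
out = functional_call(net, shifted, (x,))
(g,) = torch.autograd.grad(out.sum(), shift)
print(same)  # True
```

The module's own parameters are never modified; only the tensors in the dict change.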
The algorithm that I am trying to implement is the following:
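The steps labelled Line 4 to Line 10 in the code comments below can be written out as follows. The notation is assumed for illustration: $f(\cdot;\theta)$ is meta_net, $C$ is loss_fn, $\alpha$ is learning_rate, $(x_f, y_f)$ is the training batch, $(x_g, y_g)$ are the $M$ validation examples, and $C_i(\theta) = C(f(x_{f,i};\theta), y_{f,i})$. The last line follows from the chain rule, because $\ell_g$ depends on $\epsilon$ only through $\hat\theta(\epsilon)$.

```latex
\begin{aligned}
&\text{Line 4:}    && \hat{y}_f = f(x_f;\theta) \\
&\text{Line 5:}    && \epsilon = \mathbf{0}, \qquad \ell_f(\epsilon,\theta) = \sum_i \epsilon_i \, C_i(\theta) \\
&\text{Line 6:}    && \nabla_\theta \ell_f = \sum_i \epsilon_i \, \nabla_\theta C_i(\theta) \\
&\text{Line 7:}    && \hat\theta(\epsilon) = \theta - \alpha \sum_i \epsilon_i \, \nabla_\theta C_i(\theta) \\
&\text{Lines 8--9:} && \ell_g(\epsilon) = \frac{1}{M} \sum_{j=1}^{M} C\big(f(x_{g,j};\hat\theta(\epsilon)),\, y_{g,j}\big) \\
&\text{Line 10:}   && \frac{\partial \ell_g}{\partial \epsilon_i} = -\alpha \, \nabla_{\hat\theta}\,\ell_g(\hat\theta)^{\top} \, \nabla_\theta C_i(\theta)
\end{aligned}
```

Since $\epsilon = \mathbf{0}$, the update leaves $\theta$ numerically unchanged ($\hat\theta = \theta$), yet $\partial \ell_g / \partial \epsilon_i$ is generally nonzero. It can only be computed if $\hat\theta$ is built inside the autograd graph.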
And the problem I am encountering is the following:
Code:
import torch
# meta_net has the same architecture as the model, and the model's parameters are copied into it
# Line 4: forward pass on the training batch
y_train_meta_hat = meta_net(x_train)
# Line 5: compute the per-example loss of our prediction
cost = loss_fn(y_train_meta_hat, y_train, reduction='none')
eps = torch.zeros_like(cost, requires_grad=True)
# Line 6: get the gradient of the eps-weighted loss with respect to the parameters
l_f = torch.sum(cost * eps)
meta_net.zero_grad()
grads = torch.autograd.grad(l_f, list(meta_net.parameters()), retain_graph=True, create_graph=True)
# Line 7: update the theta values of the network
for (name, param), src in zip(meta_net.named_parameters(), grads):
    param.data -= learning_rate * src
# Line 8: compute predictions on the validation data using the updated theta values
y_val_hat = meta_net(x_val)
# Lines 9-10: compute the validation loss, then its gradient with respect to eps
l_g = loss_fn(y_val_hat, y_val, reduction='mean')
grad_eps = torch.autograd.grad(l_g, eps)  # this does not work
Specifically, I am having trouble with the last line: taking the gradient of l_g with respect to eps. The computation graph seems to be cut the moment I update the parameters with the first gradient (the .data -= learning_rate * src step in the Line 7 loop). More specifically, I get the following error:
“One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.”
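A minimal reproduction of this error, as a sketch: the `nn.Linear` model, random data, `F.mse_loss` and `lr` are assumptions standing in for meta_net, the real data, loss_fn and learning_rate. It shows that the gradients are all zero (because eps is zero) but still depend on eps, and that writing through `.data` leaves every parameter a plain leaf with no history, so l_g never sees eps.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Toy stand-ins for meta_net, the data, loss_fn and learning_rate (illustrative only)
net = nn.Linear(3, 1)
x_train, y_train = torch.randn(5, 3), torch.randn(5, 1)
x_val, y_val = torch.randn(4, 3), torch.randn(4, 1)
lr = 0.1

cost = F.mse_loss(net(x_train), y_train, reduction='none')  # shape (5, 1)
eps = torch.zeros_like(cost, requires_grad=True)
grads = torch.autograd.grad((cost * eps).sum(), list(net.parameters()), create_graph=True)

# With eps == 0 every gradient is exactly zero, but each one is still a function of eps.
all_zero = all(torch.count_nonzero(g) == 0 for g in grads)

# Writing through .data changes the stored values but records nothing in the graph.
for p, g in zip(net.parameters(), grads):
    p.data -= lr * g
still_leaf = all(p.is_leaf and p.grad_fn is None for p in net.parameters())

# The validation loss is therefore computed from leaves that never saw eps.
l_g = F.mse_loss(net(x_val), y_val)
try:
    torch.autograd.grad(l_g, eps)
    raised = False
except RuntimeError as err:
    raised = 'appears to not have been used in the graph' in str(err)
print(all_zero, still_leaf, raised)  # True True True
```

Passing allow_unused=True would only silence the error and return None for eps; it would not restore the missing dependency.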
Could anyone give some guidance on how to resolve this issue?
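One possible generic approach, sketched under assumptions: PyTorch 2.0+ for `torch.func.functional_call`, `reweight_grad_eps` is a hypothetical helper name, and loss_fn is assumed to accept a `reduction` keyword like the `torch.nn.functional` losses. The idea is to compute the updated parameters as new tensors instead of writing them into `.data`, then run the unmodified network with those tensors.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call


def reweight_grad_eps(meta_net, loss_fn, x_train, y_train, x_val, y_val, lr):
    """Hypothetical helper: one meta step returning d(l_g)/d(eps), theta update kept in the graph."""
    params = dict(meta_net.named_parameters())
    # Lines 4-5: per-example training loss, weighted by eps (zeros that require grad)
    y_train_hat = functional_call(meta_net, params, (x_train,))
    cost = loss_fn(y_train_hat, y_train, reduction='none')
    eps = torch.zeros_like(cost, requires_grad=True)
    l_f = torch.sum(cost * eps)
    # Line 6: create_graph=True so the gradients stay differentiable functions of eps
    grads = torch.autograd.grad(l_f, list(params.values()), create_graph=True)
    # Line 7: out-of-place update; theta_hat is connected to eps through grads
    theta_hat = {name: p - lr * g for (name, p), g in zip(params.items(), grads)}
    # Line 8: validation forward pass with theta_hat; meta_net's own parameters are untouched
    y_val_hat = functional_call(meta_net, theta_hat, (x_val,))
    # Lines 9-10: validation loss and its gradient with respect to eps
    l_g = loss_fn(y_val_hat, y_val, reduction='mean')
    (grad_eps,) = torch.autograd.grad(l_g, eps)
    return grad_eps


# Usage on a toy regression problem (model, data and lr are made up for illustration)
torch.manual_seed(0)
net = nn.Sequential(nn.Linear(3, 8), nn.ReLU(), nn.Linear(8, 1))
x_train, y_train = torch.randn(5, 3), torch.randn(5, 1)
x_val, y_val = torch.randn(4, 3), torch.randn(4, 1)
lr = 0.1
grad_eps = reweight_grad_eps(net, F.mse_loss, x_train, y_train, x_val, y_val, lr)
print(grad_eps.shape)  # torch.Size([5, 1])
```

Because theta_hat is an ordinary tensor expression of eps, the final `torch.autograd.grad` call finds eps in the graph. `functional_call` takes parameters from the dict and any buffers (for example BatchNorm running statistics) from the module itself, so the network definition does not need to change.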