Gradient with respect to parameters that update model parameters?

I am trying to implement this paper for a generic network. I have found an implementation that works, but it requires defining the networks in another custom class (MetaModule), which does not scale well.

The algorithm that I am trying to implement is the following:

And the problem I am encountering is the following:

```python
import torch
from torch.autograd import grad

# meta_net is created using the same network structure as model, and the params are copied over

# Line 4
y_train_meta_hat = meta_net(x_train)

# Line 5: compute per-example losses using our prediction
cost = loss_fn(y_train_meta_hat, y_train, reduction='none')
eps = torch.zeros(cost.size(), requires_grad=True)

# Line 6: get gradient with respect to parameters
l_f = torch.sum(cost * eps)
grads = torch.autograd.grad(l_f, list(meta_net.parameters()), retain_graph=True, create_graph=True)

# Line 7: update theta values of the network (in place, through .data)
for (name, tgt), src in zip(meta_net.named_parameters(), grads):
    tgt.data -= learning_rate * src

# Line 8: compute predictions on validation data using the updated theta values
y_val_hat = meta_net(x_val)

# Lines 9-10: compute loss function, and compute gradient with respect to eps
l_g = loss_fn(y_val_hat, y_val, reduction='mean')
grad_eps = grad(l_g, eps)   ### this does not work
```

Specifically, I am having trouble with the last line: taking the gradient of l_g with respect to eps. The computation graph seems to be detached the moment I update the parameters using the first gradient (tgt.data -= learning_rate * src). More precisely, I get the following error:
“One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.”

Could anyone give some guidance on how to resolve this issue?

Parameters are designed to be leaf nodes in PyTorch, so you should not expect gradients to flow back through an update applied to them. And you are right: modifying the .data of a tensor is not recorded by autograd, so it does not produce a grad_fn.
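A minimal sketch of the point above, using a hypothetical two-element tensor w in place of a model parameter and a toy squared-error loss (both assumptions for illustration only). Updating w through .data drops eps from the validation loss and reproduces the error from the question, while an out-of-place update (w - lr * g) keeps the graph, so the gradient with respect to eps can be taken.

```python
import torch

def setup():
    # Toy stand-ins: w plays the role of a parameter, eps the per-example weights.
    w = torch.tensor([1.0, 2.0], requires_grad=True)
    eps = torch.zeros(2, requires_grad=True)
    x = torch.eye(2)
    y = torch.tensor([0.5, 1.5])
    cost = (x @ w - y) ** 2  # per-example losses
    # create_graph=True keeps g differentiable, so g depends on eps.
    (g,) = torch.autograd.grad(torch.sum(cost * eps), w, create_graph=True)
    return w, eps, x, y, g

lr = 0.1

# 1) In-place update through .data: not recorded, so eps is no longer in the graph.
w, eps, x, y, g = setup()
w.data -= lr * g.detach()
val_loss = ((x @ w - y) ** 2).mean()
failed = False
try:
    torch.autograd.grad(val_loss, eps)
except RuntimeError as err:
    failed = True
    print("fails:", err)

# 2) Out-of-place update: w_new is a non-leaf tensor connected to eps through g.
w, eps, x, y, g = setup()
w_new = w - lr * g
val_loss = ((x @ w_new - y) ** 2).mean()
(grad_eps,) = torch.autograd.grad(val_loss, eps)
print(grad_eps)
```

The first branch prints the same "One of the differentiated Tensors appears to not have been used in the graph" message quoted in the question.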

You can use customized layers with backpropagatable “parameters”. An example is

Personally, I suggest using my library and calling dependents() instead of parameters().
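One way to read the suggestion of layers with backpropagatable "parameters" is sketched below. MetaLinear is a hypothetical name, not the linked example or any library's API. The layer keeps ordinary nn.Parameters but lets forward() accept replacement weights, so the updated weights can stay non-leaf tensors in the graph. The data, sizes and MSE loss are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

class MetaLinear(torch.nn.Module):
    """Hypothetical linear layer whose forward can use externally supplied weights."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features) * 0.1)
        self.bias = torch.nn.Parameter(torch.zeros(out_features))

    def forward(self, x, weight=None, bias=None):
        # Fall back to the registered parameters when no override is given.
        weight = self.weight if weight is None else weight
        bias = self.bias if bias is None else bias
        return F.linear(x, weight, bias)

torch.manual_seed(0)
layer = MetaLinear(3, 1)
x_train, y_train = torch.randn(4, 3), torch.randn(4, 1)
x_val, y_val = torch.randn(2, 3), torch.randn(2, 1)
lr = 0.1

# Per-example training losses weighted by eps.
cost = F.mse_loss(layer(x_train), y_train, reduction='none').squeeze(1)
eps = torch.zeros(cost.size(), requires_grad=True)
l_f = torch.sum(cost * eps)
gw, gb = torch.autograd.grad(l_f, (layer.weight, layer.bias), create_graph=True)

# Out-of-place update: new_w and new_b remain connected to eps.
new_w = layer.weight - lr * gw
new_b = layer.bias - lr * gb

# Validation loss with the updated weights, then its gradient w.r.t. eps.
l_g = F.mse_loss(layer(x_val, new_w, new_b), y_val)
(grad_eps,) = torch.autograd.grad(l_g, eps)
print(grad_eps.shape)  # torch.Size([4])
```

The drawback is the one raised in the question: every layer type has to be rewritten this way.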

Did you find a way to bypass this problem?

It’s just computing the gradients after the parameter updates (lines 6 and 10). I’m shocked that PyTorch does not have a way to implement this algorithm without rewriting the whole nn.Module.
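A possible way around rewriting nn.Module: PyTorch 2.0 and later provide torch.func.functional_call, which runs an existing module with a dictionary of replacement tensors. The sketch below assumes that PyTorch version, a small hypothetical nn.Sequential classifier, random data and cross-entropy loss. It follows lines 4 to 10 from the question but uses an out-of-place update, so l_g stays differentiable with respect to eps.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call

torch.manual_seed(0)
# Hypothetical model and data, for illustration only.
model = nn.Sequential(nn.Linear(3, 8), nn.ReLU(), nn.Linear(8, 2))
x_train, y_train = torch.randn(5, 3), torch.randint(0, 2, (5,))
x_val, y_val = torch.randn(4, 3), torch.randint(0, 2, (4,))
learning_rate = 0.1

params = dict(model.named_parameters())

# Lines 4-5: per-example training losses, weighted by eps
y_train_hat = functional_call(model, params, (x_train,))
cost = F.cross_entropy(y_train_hat, y_train, reduction='none')
eps = torch.zeros(cost.size(), requires_grad=True)

# Line 6: gradient w.r.t. the parameters, keeping the graph
l_f = torch.sum(cost * eps)
grads = torch.autograd.grad(l_f, list(params.values()), create_graph=True)

# Line 7: out-of-place update; the new tensors are non-leaf and depend on eps
updated = {name: p - learning_rate * g for (name, p), g in zip(params.items(), grads)}

# Line 8: run the same module with the updated tensors
y_val_hat = functional_call(model, updated, (x_val,))

# Lines 9-10: validation loss and its gradient w.r.t. eps
l_g = F.cross_entropy(y_val_hat, y_val)
(grad_eps,) = torch.autograd.grad(l_g, eps)
print(grad_eps.shape)  # torch.Size([5])
```

Because functional_call leaves the module's registered parameters untouched, this sketch does not need a separate meta_net copy. Older releases exposed a similar function as torch.nn.utils.stateless.functional_call.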