How to manually update network parameters while keeping track of its computational graph?

I am trying to implement MAML. In simplified form, these are the three operations I wish to implement, where theta denotes the weights of a neural network.

This assumes a constant input, i.e. f(theta) = loss(net(in_)), because we are only interested in the gradients w.r.t. the weights.
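
Read off the code below, the three operations would look roughly like this. This is a reconstruction from the snippet, not the original formulas; \alpha stands for lr, and \theta_1, \theta_2 are names assumed here for the intermediate weights.

```latex
\begin{align}
\theta_1 &= \theta - \alpha \, \nabla_{\theta} f(\theta) \\
\theta_2 &= \theta_1 - \alpha \, \nabla_{\theta_1} f(\theta_1) \\
g &= \nabla_{\theta} f(\theta_2)
\end{align}
```

The third line is the one that needs the graph: \theta_2 has to remain a differentiable function of \theta.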

Here’s my code snippet:

temp_net = copy.deepcopy(net)

## First formula
loss_1 = loss(net(in_))
grad_1 = torch.autograd.grad(loss_1, net.parameters(), create_graph=True)

with torch.no_grad():
    for param, grad in zip(temp_net.parameters(), grad_1):
        new_param = param - lr * grad
        param.copy_(new_param)

## Second formula
loss_2 = loss(temp_net(in_))
grad_2 = torch.autograd.grad(loss_2, temp_net.parameters(), create_graph=True)

with torch.no_grad():
    for param, grad in zip(temp_net.parameters(), grad_2):
        new_param = param - lr * grad
        param.copy_(new_param)

## Third formula
loss_3 = loss(temp_net(in_))
grad_3 = torch.autograd.grad(loss_3, net.parameters())

The last line throws an error because the computational graph is disconnected when I use with torch.no_grad(), so I can’t compute the gradient w.r.t. net.parameters(). However, if I remove torch.no_grad(), I get RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

My question is: how can I implement the given formulas? How can I update my network parameters manually while maintaining the computational graph?

Looking at the learn2learn implementation of MAML, it turns out we can update the parameters manually and still keep the computational graph, by overwriting the entries of ._parameters with the updated tensors instead of using .copy_(). However, this requires recursively visiting the ._parameters property of every module in ._modules of our model. This is the manual in-place update and this is the recursive part.
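
Here is a minimal sketch of that idea applied to the snippet above. It is not learn2learn's code: replace_parameters is a hypothetical helper, the small Sequential network, the random data and the MSE loss are placeholders, and it assumes a model without shared parameters so that the traversal order matches net.parameters().

```python
import copy

import torch
import torch.nn as nn


def replace_parameters(module, new_params):
    """Recursively overwrite module._parameters with tensors from an iterator.

    Hypothetical helper. The new tensors are non-leaf results of
    `param - lr * grad`, so they stay attached to the computational graph.
    """
    for name, old in list(module._parameters.items()):
        if old is not None:
            module._parameters[name] = next(new_params)
    for child in module._modules.values():
        if child is not None:
            replace_parameters(child, new_params)


# Placeholder setup (assumed, not from the original post)
torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))
in_ = torch.randn(16, 4)
target = torch.randn(16, 1)
lr = 0.1


def loss(out):
    return ((out - target) ** 2).mean()


# Copy first: deepcopy only works while the parameters are still leaf tensors
temp_net = copy.deepcopy(net)

## First formula
loss_1 = loss(net(in_))
grad_1 = torch.autograd.grad(loss_1, list(net.parameters()), create_graph=True)
theta_1 = [p - lr * g for p, g in zip(net.parameters(), grad_1)]
replace_parameters(temp_net, iter(theta_1))  # no torch.no_grad(), no .copy_()

## Second formula
loss_2 = loss(temp_net(in_))
grad_2 = torch.autograd.grad(loss_2, theta_1, create_graph=True)
theta_2 = [p - lr * g for p, g in zip(theta_1, grad_2)]
replace_parameters(temp_net, iter(theta_2))

## Third formula: gradient w.r.t. the original weights, through both updates
loss_3 = loss(temp_net(in_))
grad_3 = torch.autograd.grad(loss_3, list(net.parameters()))
```

After this, temp_net holds plain tensors rather than nn.Parameter objects, so it is only useful as a throwaway "fast weights" copy; grad_3 is what an outer optimizer would apply to net. On PyTorch 2.x, torch.func.functional_call with a dict of updated tensors should be an alternative that avoids touching ._parameters at all.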