I am trying to implement MAML. Simplified, these are the three operations I wish to implement, where theta denotes the weights of a neural network. Assume a constant input, i.e. f(theta) = loss(net(in_)), since we are only interested in the gradients w.r.t. the weights.
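For reproducibility, the names net, loss, in_ and lr in my snippet stand for something like the following. This is a hypothetical minimal stand-in, not my real model; only the shapes of the problem matter:

```python
import torch

# Hypothetical stand-ins so the snippet below is self-contained:
# a tiny network, a placeholder loss, a fixed input batch, and an
# inner-loop learning rate. My real setup differs.
torch.manual_seed(0)
net = torch.nn.Sequential(
    torch.nn.Linear(3, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 1),
)
loss = lambda out: (out ** 2).mean()  # placeholder loss
in_ = torch.randn(4, 3)               # constant input
lr = 0.1                              # inner-loop step size
```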
Here’s my code snippet:
import copy
import torch

temp_net = copy.deepcopy(net)

## First formula
loss_1 = loss(net(in_))
grad_1 = torch.autograd.grad(loss_1, net.parameters(), create_graph=True)
with torch.no_grad():
    for param, grad in zip(temp_net.parameters(), grad_1):
        new_param = param - lr * grad
        param.copy_(new_param)

## Second formula
loss_2 = loss(temp_net(in_))
grad_2 = torch.autograd.grad(loss_2, temp_net.parameters(), create_graph=True)
with torch.no_grad():
    for param, grad in zip(temp_net.parameters(), grad_2):
        new_param = param - lr * grad
        param.copy_(new_param)

## Third formula
loss_3 = loss(temp_net(in_))
grad_3 = torch.autograd.grad(loss_3, net.parameters())
The last line, grad_3 = torch.autograd.grad(loss_3, net.parameters()), throws an error: because of the with torch.no_grad() blocks, the computational graph is disconnected and I can't compute the gradient w.r.t. net.parameters(). However, if I remove torch.no_grad(), I instead get RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

My question is: how can I implement the given formulas? How can I update my network parameters manually while maintaining the computational graph?
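One direction I've been experimenting with, in case it helps frame the question: avoid the in-place copy_ entirely and run a functional forward pass with torch.func.functional_call (PyTorch 2.x), so each updated parameter set is just a dict of ordinary graph tensors. I'm not sure this is the idiomatic way. A sketch with a hypothetical tiny net and loss (not my real model):

```python
import torch
from torch.func import functional_call

# Hypothetical tiny stand-ins for net / loss / in_ / lr.
torch.manual_seed(0)
net = torch.nn.Linear(3, 1)
loss = lambda out: (out ** 2).mean()
in_ = torch.randn(4, 3)
lr = 0.1

# Keep parameters in a dict instead of mutating the module in-place.
params = dict(net.named_parameters())

## First formula: theta' = theta - lr * grad f(theta)
loss_1 = loss(functional_call(net, params, (in_,)))
grad_1 = torch.autograd.grad(loss_1, tuple(params.values()), create_graph=True)
params_1 = {n: p - lr * g for (n, p), g in zip(params.items(), grad_1)}

## Second formula: theta'' = theta' - lr * grad f(theta')
loss_2 = loss(functional_call(net, params_1, (in_,)))
grad_2 = torch.autograd.grad(loss_2, tuple(params_1.values()), create_graph=True)
params_2 = {n: p - lr * g for (n, p), g in zip(params_1.items(), grad_2)}

## Third formula: grad of f(theta'') w.r.t. the ORIGINAL theta
loss_3 = loss(functional_call(net, params_2, (in_,)))
grad_3 = torch.autograd.grad(loss_3, tuple(params.values()))
```

Because the updates are out-of-place, the graph from params through params_1 and params_2 stays connected, so the final grad call does not error in my tests. I'd still like to know whether this is the recommended approach.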