Hi there. Let’s say we obtain a set of gradients for a set of weights as follows:

grad = torch.autograd.grad([output], [m for m in self.parameters()],
                           grad_outputs=[grad_input],
                           create_graph=True, retain_graph=retain_graph,
                           only_inputs=True)
for i, m in enumerate(self.parameters()):
    m.data = m.data - lr * grad[i]

Let net be the model whose weights were updated this way. We then run a forward pass with the updated model, net(input). Since I retained the graph in the autograd.grad call above, I'd like loss.backward() to propagate the gradient through the grad[i] terms back into grad_input. Any idea how to achieve this?

Here’s a simple flow of the program:

gradOut = self.netA(input_data)
self.netB.paramUpdate(gradOut) # performs the code snippet provided above
out = self.netB(input_data)
loss = self.criterion(out, label)
self.optimizer.zero_grad()
loss.backward() # I hope that the gradients can flow from netB to netA

Hi, the problem is that you use .data, which breaks the autograd history and prevents the gradient from being computed.
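
A minimal sketch of the effect, assuming a single nn.Linear stands in for one layer of netB and a hypothetical tensor `upstream` stands in for whatever netA contributed. The right-hand side of the update has history back to `upstream`, but assigning it through .data keeps only the values:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
upstream = torch.randn(1, 2, requires_grad=True)  # hypothetical stand-in for netA's contribution
lin = nn.Linear(2, 1, bias=False)                  # stand-in for one layer of netB

# Same pattern as in the question: the new value depends on `upstream`,
# but writing it through .data copies the numbers and drops the graph.
lin.weight.data = lin.weight.data - 0.1 * upstream

loss = lin(torch.ones(3, 2)).sum()
loss.backward()
print(upstream.grad)                # None: no gradient flowed back to `upstream`
print(lin.weight.grad is not None)  # True: only the leaf Parameter received a gradient
```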

In particular, if netB does not have any learnable parameters, you should remove them (for example, del netB.some_mod.weight) and then assign the computed value in their place: netB.some_mod.weight = lr * grad[i]
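
A small sketch of the del-then-assign suggestion, assuming hypothetical stand-ins: `NetB` has one bias-free layer named `some_mod`, and the update is just some tensor computed from netA's output and reshaped to match `some_mod.weight`. After the del, the attribute is a plain tensor, so its history back to netA is preserved:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class NetB(nn.Module):
    """Hypothetical stand-in for netB: a single bias-free layer named some_mod."""
    def __init__(self):
        super().__init__()
        self.some_mod = nn.Linear(3, 2, bias=False)

    def forward(self, x):
        return self.some_mod(x)

net_a = nn.Linear(3, 6)  # hypothetical stand-in for netA
net_b = NetB()
x = torch.randn(4, 3)
lr = 0.1

# Something computed by netA, reshaped to match some_mod.weight (2 x 3).
update = net_a(x).mean(dim=0).view(2, 3)

# Remove the registered Parameter, then attach the computed tensor as a plain attribute.
del net_b.some_mod.weight
net_b.some_mod.weight = lr * update

loss = net_b(x).pow(2).mean()
loss.backward()
print(net_a.weight.grad is not None)  # True: the gradient reached netA
print(list(net_b.parameters()))       # []: netB has no learnable parameters left
```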

The reason I use .data is that without it, I found that the parameters of netB are not updated.
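
One likely explanation, offered as an assumption about what happened here: nn.Module refuses to overwrite a registered Parameter with a plain tensor, and wrapping the new value in nn.Parameter is accepted but creates a fresh leaf with no history. A sketch of both behaviours on a single nn.Linear:

```python
import torch
import torch.nn as nn

lin = nn.Linear(3, 2)
new_weight = lin.weight - 0.1 * torch.ones(2, 3)  # non-leaf tensor with a grad_fn

# 1) A plain tensor cannot replace a registered Parameter.
raised = False
try:
    lin.weight = new_weight
except TypeError:
    raised = True
print(raised)  # True

# 2) nn.Parameter(...) is accepted, but it detaches: the result is a new leaf.
lin.weight = nn.Parameter(new_weight)
print(lin.weight.grad_fn is None)  # True: the link to the old computation is gone
```

Either way the update is cut off from the graph, which is why the del-then-assign approach matters.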

Both netA and netB have learnable parameters. The idea is that when I'm doing a forward pass on netB, its parameters are netB's original parameters plus an update derived from the gradient provided by netA. So when I backprop, I'm hoping that the gradient will update both netB and netA, since the provided gradient should be part of the computational graph.

I found that pytorch-meta takes a similar approach: it pre-computes the updated parameters and re-inserts them into the existing model. Is this the most elegant way to accomplish this?
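
In recent PyTorch (2.0 and later), torch.func.functional_call runs a module with a substitute dictionary of parameters without modifying the module itself, which is close in spirit to the pre-compute-and-re-insert idea. A sketch of the question's full flow under that approach, assuming single nn.Linear layers as hypothetical stand-ins for netA and netB and an MSE loss chosen only for illustration:

```python
import torch
import torch.nn as nn
from torch.func import functional_call  # available in PyTorch 2.0+

torch.manual_seed(0)
net_a = nn.Linear(3, 2)  # hypothetical netA; its output plays the role of grad_input
net_b = nn.Linear(3, 2)  # hypothetical netB
x = torch.randn(4, 3)
target = torch.randn(4, 2)
lr = 0.1

# netA supplies grad_outputs for netB's output.
grad_out = net_a(x)

# Gradient of netB's output w.r.t. its parameters; create_graph=True keeps the
# result differentiable w.r.t. grad_out, and therefore w.r.t. netA.
names, params = zip(*net_b.named_parameters())
out = net_b(x)
grads = torch.autograd.grad(out, params, grad_outputs=grad_out, create_graph=True)

# Updated parameters as ordinary tensors (no .data), keyed by parameter name.
updated = {n: p - lr * g for n, p, g in zip(names, params, grads)}

# Forward through netB using the updated tensors; netB's stored Parameters are untouched.
out2 = functional_call(net_b, updated, (x,))
loss = nn.functional.mse_loss(out2, target)
loss.backward()
print(net_a.weight.grad is not None, net_b.weight.grad is not None)  # True True
```

Because the original Parameters stay registered, a single optimizer over both models' parameters can step netA and netB after this backward pass.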

You will have to distinguish between the learned parameters of netB and the tensors that are computed at each forward pass.
One thing you can do is rename parameters like weight to weight_orig, then del the weight attribute. In the forward pass, set weight to weight_orig + your_update.
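
A sketch of the rename-and-rebuild pattern on an nn.Linear, assuming a hypothetical helper `make_updatable` and a hypothetical plain attribute `weight_update` that is set from netA before each forward. A forward pre-hook rebuilds `weight` as a non-leaf tensor, so gradients reach both `weight_orig` and whatever produced the update:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def make_updatable(lin: nn.Linear) -> nn.Linear:
    """Hypothetical helper: learned weight lives in `weight_orig`; `weight` is rebuilt per forward."""
    w = lin.weight
    del lin.weight                                       # drop the registered 'weight' Parameter
    lin.weight_orig = nn.Parameter(w.detach().clone())   # learned part, still a Parameter
    lin.weight_update = torch.zeros_like(w.detach())     # plain attribute, set from outside

    def rebuild_weight(module, inputs):
        # Non-leaf sum: keeps history to both weight_orig and weight_update.
        module.weight = module.weight_orig + module.weight_update

    lin.register_forward_pre_hook(rebuild_weight)
    return lin

net_a = nn.Linear(3, 6)                  # hypothetical stand-in for netA
net_b = make_updatable(nn.Linear(3, 2))  # stand-in for netB
x = torch.randn(4, 3)
target = torch.randn(4, 2)

# The update is computed from netA's output, so it carries netA's graph.
net_b.weight_update = -0.1 * net_a(x).mean(dim=0).view(2, 3)
loss = nn.functional.mse_loss(net_b(x), target)
loss.backward()
print(net_a.weight.grad is not None)       # True
print(net_b.weight_orig.grad is not None)  # True
```

This is the same idea torch.nn.utils.weight_norm uses internally (a renamed parameter plus a pre-hook), so an optimizer built over net_a.parameters() and net_b.parameters() will see weight_orig and bias rather than weight.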