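import torch

temp_a = torch.rand(1, 4)  # replacement values
a = torch.rand(1, 4)       # original input / parameters
b = a**2
c = b*2
d = c.mean()
a.copy_(temp_a)            # overwrite a in place with temp_a after the forward
d.backward()               # is the gradient w.r.t. a or w.r.t. temp_a?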

Hi, in this toy example, suppose I replace the original parameters or input of the model (e.g. a) with another set of parameters (e.g. temp_a) and then call backward. Is the computed gradient w.r.t. the original input a or w.r.t. temp_a?

This is just a toy example to show my intention. In my actual implementation:

I clone the model parameters at the beginning.

Then I do an intermediate step that runs a forward pass and a backward pass, and updates the model parameters with autograd.

I run the model forward again and get a new output. I want the gradient of this output w.r.t. the original model parameters I cloned, so I copy the original parameters back into the model, just as I did in the toy example. Then I call backward to get the gradient.

Finally, I update the original model parameters with the gradient from the last step.

Will this work as I expect? Thank you for your help.

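In your example, if you actually make a require gradients, this will not work: autograd will raise an error saying that you can’t modify a in-place in a differentiable manner. If you instead change it in a non-differentiable manner with torch.no_grad(), the copy is not recorded, so backward differentiates the graph built in the forward: the gradient is w.r.t. a as it was used in the forward, never w.r.t. temp_a. Note that the in-place copy still bumps a’s version counter, so if a was saved for backward (as it is by a**2 here), backward will raise an in-place modification error instead of returning a gradient.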
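A minimal sketch of both cases from the toy example, assuming a recent PyTorch version. Error messages differ between versions, so the code only records whether an error was raised. Case 2 fails because a ** 2 saves a for the backward pass and the no_grad copy bumps a’s version counter; with ops that do not save a, backward would run on the graph recorded in the forward.

```python
import torch

temp_a = torch.rand(1, 4)
a = torch.rand(1, 4, requires_grad=True)

# Case 1: a differentiable in-place copy into a leaf that requires grad is rejected.
try:
    a.copy_(temp_a)
    case1_error = None
except RuntimeError as e:
    case1_error = str(e)

# Case 2: the same copy under torch.no_grad() is allowed and not recorded by autograd.
d = (a ** 2 * 2).mean()   # a ** 2 saves a for the backward pass
with torch.no_grad():
    a.copy_(temp_a)       # untracked, but it increments a's version counter
try:
    d.backward()          # the saved a no longer matches its recorded version
    case2_error = None
except RuntimeError as e:
    case2_error = str(e)

print("case 1 raised:", case1_error is not None)  # True
print("case 2 raised:", case2_error is not None)  # True
print("temp_a.grad:", temp_a.grad)  # None: temp_a does not require grad and is not in the graph
```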

For the nn.Module version, note that the optimizer step and the state dict loading are both done in a non-differentiable manner, so autograd ignores these ops.
When you call backward, you will get the gradient w.r.t. the Parameter that was used in the forward.
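A small sketch, assuming a single nn.Linear layer and SGD as stand-ins for the real model and optimizer, showing that optimizer.step() and load_state_dict() only change the values of the existing Parameter in place: the Parameter object stays the same leaf tensor and no autograd history is attached to it.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

saved = copy.deepcopy(model.state_dict())  # "cloned" original parameter values
weight = model.weight                      # handle on the Parameter object itself

loss = model(torch.rand(2, 4)).sum()
loss.backward()
opt.step()  # in-place update of model.weight, not recorded by autograd
changed = not torch.equal(model.weight, saved["weight"])

model.load_state_dict(saved)  # in-place copy of the saved values, also not recorded

print(changed)                                     # True: the step changed the values
print(model.weight is weight)                      # True: still the same Parameter
print(model.weight.is_leaf, model.weight.grad_fn)  # True None: no history recorded
print(torch.equal(model.weight, saved["weight"]))  # True: original values restored
```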

Using .data is the same as “doing it in a non-differentiable way”.
Note that you should never use .data anymore; use with torch.no_grad() instead, as it is much safer (search this forum for the side effects of .data if you want to know more).

autograd.grad() or .backward() will give the same result here.
Both compute the gradient as if you had not changed the parameters (and because you use .data, the in-place change is also hidden from the in-place correctness checks, so these gradients can be computed wrong without any error).
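A minimal sketch of the .data pitfall described above, assuming a recent PyTorch: a.data shares a’s storage but not its version counter, so the in-place copy is not flagged, and backward quietly uses the overwritten values saved in that storage.

```python
import torch

torch.manual_seed(0)
temp_a = torch.rand(1, 4)
a = torch.rand(1, 4, requires_grad=True)

d = (a ** 2).sum()    # gradient of d w.r.t. a is 2 * a; pow saves a for backward
a.data.copy_(temp_a)  # .data has its own version counter, so nothing is detected
d.backward()          # runs without any error...

# ...but the gradient was computed from the overwritten storage, i.e. from temp_a
print(torch.allclose(a.grad, 2 * temp_a))  # True
```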

I can’t think of an easy way to do it, I’m afraid. Have you looked at libraries like higher (https://github.com/facebookresearch/higher)? They might do what you want in a cleaner way.
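higher’s own API is not shown here. As a different tool that is not mentioned in the thread, PyTorch 2.0+ provides torch.func.functional_call, which runs a module with a dict of tensors substituted for its parameters. That avoids copying values back into the module at all. A minimal sketch, assuming a single nn.Linear layer, an MSE loss, and a hypothetical inner learning rate inner_lr:

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)
model = nn.Linear(4, 1)
x, y = torch.rand(8, 4), torch.rand(8, 1)
inner_lr = 0.1  # hypothetical inner-step learning rate

# The module's own Parameters play the role of the "original" parameters.
params = dict(model.named_parameters())

# Inner step: a differentiable update producing "fast" weights that depend on params.
inner_loss = nn.functional.mse_loss(functional_call(model, params, (x,)), y)
grads = torch.autograd.grad(inner_loss, list(params.values()), create_graph=True)
fast = {name: p - inner_lr * g for (name, p), g in zip(params.items(), grads)}

# Outer step: run the same module with the fast weights, then backprop to params.
outer_loss = nn.functional.mse_loss(functional_call(model, fast, (x,)), y)
outer_loss.backward()

print(model.weight.grad.shape)  # torch.Size([1, 4])
```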

Hi, thank you. I found a method in this thread: "[resolved] Implementing MAML in PyTorch". They use functional ops such as functional.linear, which take the weights as explicit arguments, to achieve this effect. It is a bit complicated because I have to handle the weights manually, but it seems to work.
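A minimal sketch of that functional approach for a single linear layer, following the steps from the question. x, y, inner_lr and the MSE loss are hypothetical placeholders, not taken from the linked thread. The updated weights are new tensors computed from the originals with create_graph=True, so the second backward reaches the original w and b without copying anything back.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x, y = torch.rand(8, 4), torch.rand(8, 1)
inner_lr = 0.1  # hypothetical inner-step learning rate

# Original parameters: plain leaf tensors passed explicitly to F.linear.
w = torch.randn(1, 4, requires_grad=True)
b = torch.zeros(1, requires_grad=True)

# Intermediate step: forward, backward and update, all kept in the graph.
inner_loss = F.mse_loss(F.linear(x, w, b), y)
gw, gb = torch.autograd.grad(inner_loss, (w, b), create_graph=True)
w_fast, b_fast = w - inner_lr * gw, b - inner_lr * gb  # new tensors; w and b untouched

# Second forward with the updated weights; backward reaches the original w and b.
outer_loss = F.mse_loss(F.linear(x, w_fast, b_fast), y)
outer_loss.backward()

# Final step: update the original parameters with that gradient, outside autograd.
with torch.no_grad():
    w -= inner_lr * w.grad
    b -= inner_lr * b.grad
```

Because w_fast and b_fast are ordinary tensors in the graph rather than in-place overwrites, none of the in-place checks discussed above come into play.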