Compute gradient wrt different weights

How can I compute the gradient w.r.t. weights different from the ones used to compute the loss?

To clarify my problem: I have a neural network model with weights W1. I train the model on some data, so the weights change and are now W2. Then I compute the loss L using W2 on a new set of data. Now I would like to update the weights W1 using the gradient of L (which was computed using W2) w.r.t. W1.

I tried to reload the weights W1 using load_state_dict between computing the loss and calling loss.backward(), but it does not work. I think the problem is that load_state_dict changes some internal states and variables in addition to the weights.

How can I solve my problem?
Thank you


Keep in mind that state_dict() does not copy your weights. So after you update your model to contain W2, your state dict also contains W2.
You might want to do it by hand: save the weights with saved_params = [p.clone() for p in model.parameters()] and load them back with

with torch.no_grad():
  for p, old_val in zip(model.parameters(), saved_params):
    p.copy_(old_val)  # restore the saved value in place

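To illustrate the point about state_dict() above, here is a minimal sketch. The toy nn.Linear model, the SGD optimizer, and the variable names are assumptions made for the example; they are not from the original question.

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 1)  # toy model, assumed for illustration

# state_dict() returns tensors that share storage with the live parameters.
aliased = model.state_dict()
# Independent snapshots of W1: a deep copy, or cloned parameters.
snapshot = copy.deepcopy(model.state_dict())
saved_params = [p.detach().clone() for p in model.parameters()]

# One SGD step moves the weights from W1 to W2 in place.
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.ones(4, 3)).pow(2).mean()
loss.backward()
opt.step()

# The aliased state dict followed the update; the deep copy still holds W1.
aliased_follows_model = torch.equal(aliased["weight"], model.weight)
snapshot_kept_w1 = not torch.equal(snapshot["weight"], model.weight)

# Restore W1 in place, without recording the copy in autograd.
with torch.no_grad():
    for p, old_val in zip(model.parameters(), saved_params):
        p.copy_(old_val)

restored = torch.equal(model.weight, snapshot["weight"])
```

Note that this only snapshots and restores values; it does not make a later backward pass differentiate with respect to the restored weights.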
Thank you for the explanation of state_dict.

However, I found this thread, "Understanding load_state_dict() effect on computational graph", where they explain that my solution is not a good one. But I do not understand how to use the nn.functional interface to change the parameters w.r.t. which the gradients are computed during the backward pass.
Can you help me?

The custom reloading of the weights above should do the right thing for you.
You don’t want to perform a forward pass with custom weights here; you want to snapshot and reload the weights, right?

I want to compute the gradient of a loss computed using W2, w.r.t. W1, where W2 is the current state of the network and W1 is a previous state.

I’m not sure I understand, then.
Do you want the gradient computation to take into account the first forward/backward, the optimizer step and the second forward?
Or do you just want to use the gradients computed with W2 and apply them to W1?

Sorry, it was not clear.

W1 and W2 are both weights of the neural network, but at two distinct moments in time.
L is the loss function
grad(W, x) is the gradient of x w.r.t. W
net(W, x) is the output of the neural network for x as input when the network is in state W.

I want to compute grad(W1, L(net(W2,x))) where x corresponds to training data.

How do you define W2 here?
Do you have W2 = grad_step(W1, grad(W1, L(net(W1, x)))) or W2 = grad_step(W1, grad(W1, L(net(W1, x))).detach())?

W2 is defined as W2 = grad_step(W1, grad(W1, L(net(W1, x)))).
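To make the difference between the two definitions concrete, here is a short derivation. It assumes grad_step is a plain SGD step with learning rate \eta, and it writes x' for the new data; neither the optimizer nor these symbols are specified in the thread.

```latex
W_2 = W_1 - \eta \, \nabla_{W_1} L(\mathrm{net}(W_1, x))

\nabla_{W_1} L(\mathrm{net}(W_2, x'))
  = \left( \frac{\partial W_2}{\partial W_1} \right)^{\top} \nabla_{W_2} L(\mathrm{net}(W_2, x'))
  = \left( I - \eta H_1 \right) \nabla_{W_2} L(\mathrm{net}(W_2, x')),
\qquad H_1 = \nabla^2_{W_1} L(\mathrm{net}(W_1, x))
```

With the .detach() variant, the inner gradient is treated as a constant, so the \eta H_1 term vanishes and the gradient w.r.t. W1 is simply the gradient w.r.t. W2. Without .detach(), the Hessian term has to be propagated through the update step.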

So this is going to be a bit more tricky.
The first backward with W1 will need to set create_graph=True.
But then you won’t be able to use an optimizer to do the step, as optimizers operate in place and do not register their operations with autograd.
So you either need to reimplement the optimizer step in a differentiable manner, assigning the new values without deleting the old ones and keeping gradients flowing (which can be tricky),
or check libraries such as higher that already do similar things.
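As a rough sketch of the first option, the update can be written out of place so that W2 stays a differentiable function of W1. This assumes a plain SGD update, a toy nn.Linear model, random data, and torch.func.functional_call (available in PyTorch 2.0 and later); the step size lr and all variable names are hypothetical.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)
net = nn.Linear(3, 1)   # toy model, assumed for illustration
loss_fn = nn.MSELoss()
lr = 0.1                # hypothetical step size

x_train, y_train = torch.randn(8, 3), torch.randn(8, 1)
x_new, y_new = torch.randn(8, 3), torch.randn(8, 1)

# W1: the parameters currently stored in the module.
w1 = dict(net.named_parameters())

# First pass with W1; create_graph=True keeps the gradient differentiable.
loss1 = loss_fn(functional_call(net, w1, (x_train,)), y_train)
grads = torch.autograd.grad(loss1, list(w1.values()), create_graph=True)

# Out-of-place SGD step: W2 is a function of W1, and W1 is left untouched.
w2 = {name: p - lr * g for (name, p), g in zip(w1.items(), grads)}

# Second pass runs the same module with W2 through the functional interface.
loss2 = loss_fn(functional_call(net, w2, (x_new,)), y_new)

# Gradient of L(net(W2, x_new)) with respect to W1, through the update step.
grads_w1 = torch.autograd.grad(loss2, list(w1.values()))
```

The tensors in grads_w1 line up with net.parameters(), so they can then be applied to W1 by hand under torch.no_grad(). Other optimizers (momentum, Adam) would need their own differentiable update rule, which is the part that libraries such as higher take care of.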
