Replace the input after forward but before backward?

import torch

temp_a = torch.rand(1, 4)
a = torch.rand(1, 4)
b = a**2
c = b*2
d = c.mean()
a.copy_(temp_a)  # overwrite the input in place, after the forward but before the backward
d.backward()

Hi, in this toy example, suppose I replace the original parameters or input of the model (e.g. a) with another set of parameters (e.g. temp_a) and then call backward. Is the computed gradient taken w.r.t. the original input a or w.r.t. temp_a?

This is just a toy example that shows my intention. In my implementation:

  1. I clone the model parameters at the beginning (see the sketch after this list).
  2. Then I run an intermediate step that forwards, backwards, and updates the model parameters with autograd.
  3. I forward the model again and get a new output. I want the gradient of this new output w.r.t. the original model parameters I cloned, so I copy the original parameters back into the model, just as in the toy example, and call backward to get the gradient.
  4. Finally, I update the original model parameters with the gradient obtained in the last step.
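
For reference, step 1 is just a snapshot of the current parameter values (a minimal sketch, with nn.Linear standing in for my actual network):

import torch
import torch.nn as nn

model = nn.Linear(4, 1)  # stand-in for the real network
# step 1: snapshot the current parameter values
initial_parameters = [p.detach().clone() for p in model.parameters()]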

Can this work as I expect? Thank you for your help.

Hi,

For your example, this will not work if you actually make a require gradients: it will tell you that you can't modify a in place in a differentiable manner. If instead you change it in a non-differentiable manner with torch.no_grad(), it will compute the gradient w.r.t. the original value of a (the one used in the forward).
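
Here is a minimal sketch of that first failure mode, taking your toy example and adding requires_grad=True:

import torch

a = torch.rand(1, 4, requires_grad=True)
d = (a**2 * 2).mean()
try:
    a.copy_(torch.rand(1, 4))  # in-place write to a leaf that requires grad
except RuntimeError as e:
    # "a leaf Variable that requires grad is being used in an in-place operation"
    print(e)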

For the nn.Module version, note that the optimizer step and loading a model state dict are both done in a non-differentiable manner, so these ops are ignored by autograd.
When you call backward, you get the gradient w.r.t. the Parameter values that were used in the forward.
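
A quick way to see that such updates leave no trace in the graph (a sketch with a toy linear model):

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.rand(8, 4)).pow(2).mean()
loss.backward()
opt.step()  # in-place, non-differentiable parameter update
print(model.weight.grad_fn)  # None: the update created no autograd node
print(model.weight.is_leaf)  # True: still a plain leaf Parameter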

Hi, thank you for your answer.

I didn't use load_state_dict to copy the parameters. Instead, I use:

def assign_model_parameters(model, temp_parameters):
    params = list(model.parameters())
    assert len(temp_parameters) == len(params)
    for param, temp in zip(params, temp_parameters):
        # in-place, non-differentiable copy through .data
        param.data.copy_(temp.data)

Also, if I use torch.autograd.grad instead of backward:

assign_model_parameters(model, initial_parameters)
model_parameters = list(model.parameters())
grads = torch.autograd.grad(second_loss, model_parameters, allow_unused=True)

where second_loss is the loss from the third step above. Then I manually apply the update:

def update_parameters(temp_parameters, loss, alpha, clip_grad=5.0):
    weights = temp_parameters
    # gradients of the loss w.r.t. the given weights; unused weights get None
    grads = torch.autograd.grad(loss, weights, allow_unused=True)
    direct_clip_grad_norm_(grads, clip_grad)  # my own clipping helper
    new_weights = []
    for w, g in zip(weights, grads):
        if g is not None:
            new_weights.append(w - alpha * g.detach())
        else:
            new_weights.append(w)  # weight unused in the loss: keep as-is
    assert len(new_weights) == len(weights)
    return new_weights

Can this work as I expect?

Hi,

Using .data is the same as "doing it in a non-differentiable way".
Note that you should never use .data anymore; use with torch.no_grad() instead, as it is much safer (search this forum for the side effects of .data if you want to know more).

.grad() and .backward() will give the same thing here.
Either one will compute the gradient as if you had not changed the parameters (and because you use .data, the in-place change is ignored by the in-place checks, so these gradients can potentially come out wrong).
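
For example, your copy helper could be written without .data (a quick sketch, keeping your assign_model_parameters name):

import torch

def assign_model_parameters(model, temp_parameters):
    # non-differentiable copy; unlike .data, this keeps the version-counter
    # checks, so autograd can still detect an invalid backward afterwards
    with torch.no_grad():
        for param, temp in zip(model.parameters(), temp_parameters):
            param.copy_(temp)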

Thank you. Could you give me some suggestions on how to achieve the effect I expect?

I can't think of an easy way to do it, I'm afraid. Have you looked at something like higher (https://github.com/facebookresearch/higher), which might do what you want in a cleaner way?
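
The typical pattern there looks like this (a rough sketch based on higher's README; model, opt, loss_fn, x and y are assumed to exist):

import higher  # pip install higher

with higher.innerloop_ctx(model, opt) as (fmodel, diffopt):
    inner_loss = loss_fn(fmodel(x), y)
    diffopt.step(inner_loss)   # differentiable inner update
    outer_loss = loss_fn(fmodel(x), y)
    outer_loss.backward()      # grads reach the original model parameters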

Hi, thank you. I found a method in this question: [resolved] Implementing MAML in PyTorch. They use functional ops such as F.linear, which take the weights as explicit arguments, to achieve this effect. It is a little complicated because I have to handle the weights manually, but it seems to work.
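
The pattern boils down to something like this (a minimal single-layer sketch of what I ended up with; the 0.1 learning rate is arbitrary):

import torch
import torch.nn.functional as F

w = torch.rand(1, 4, requires_grad=True)  # the "original" weight
x = torch.rand(8, 4)

inner_loss = F.linear(x, w).pow(2).mean()
g, = torch.autograd.grad(inner_loss, w, create_graph=True)
w_fast = w - 0.1 * g                      # differentiable update, stays in the graph

outer_loss = F.linear(x, w_fast).pow(2).mean()
outer_loss.backward()                     # gradient flows back to the original w
print(w.grad.shape)                       # torch.Size([1, 4])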