This is a problem related to higher-order gradients.

In this simple case:

```
import torch

a = torch.tensor([1., 2., 3.], requires_grad=True)
b = torch.tensor([1., 2., 3.], requires_grad=True)
c = a * b
loss = c.sum()
# create_graph=True builds a graph for the gradients themselves,
# so a.grad remains differentiable with respect to b
loss.backward(create_graph=True)
c = a - a.grad  # manual SGD update of a; still connected to b through a.grad
loss2 = c.sum()
```

Here, **b** is the tensor I need to optimize, and **a** is a model parameter. In the code above, **c** is the updated version of **a**, and **loss2** is a loss that can backpropagate all the way to **b**. However, the computation graph from **a** to **loss2** may be complex, since **a** can be a parameter of a huge model.
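To make that concrete, a quick check (a minimal sketch appended to the first snippet; `torch.autograd.grad` is standard PyTorch) shows that the gradient of **loss2** with respect to **b** is well defined:

```
# Continuing the first snippet: a.grad equals b here, so c = a - b
grad_b, = torch.autograd.grad(loss2, b)
print(grad_b)  # tensor([-1., -1., -1.]), i.e. d(sum(a - b))/db
```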

```
import torch
from torch.optim import SGD

a = torch.tensor([1., 2., 3.], requires_grad=True)
b = torch.tensor([1., 2., 3.], requires_grad=True)
opt = SGD([a], lr=1.)
c = a * b
loss = c.sum()
loss.backward(create_graph=True)
opt.step()  # updates a in place under no_grad, cutting the graph back to b
loss2 = a.sum()
```
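The breakage can be verified directly (a small check appended to the second snippet; `allow_unused=True` is needed because **b** no longer appears in the graph of **loss2**):

```
grad_b, = torch.autograd.grad(loss2, b, allow_unused=True)
print(grad_b)  # None: the optimizer step severed the path from b to loss2
```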

In this version, **loss2** cannot backpropagate to **b**, because `opt.step()` modifies **a** in place outside the autograd graph. Is there a simple way to implement what I need, i.e., making the updated parameter **a** an interim tensor induced by **b**, so that **a** can still be used to compute **loss2**?
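For reference, one possible shape of a solution (a hedged sketch, not a confirmed answer: it assumes PyTorch 2.x with `torch.func.functional_call`, and the tiny `Linear` model only stands in for the huge one) is to replace `opt.step()` with a manual, functional parameter update:

```
import torch
from torch.func import functional_call

model = torch.nn.Linear(3, 1, bias=False)  # illustrative stand-in for the real model
b = torch.tensor([1., 2., 3.], requires_grad=True)
x = torch.randn(4, 3)

# Inner loss that depends on both the model parameters and b
inner_loss = model(x * b).sum()

# Differentiable gradients w.r.t. the parameters (create_graph keeps the link to b)
params = dict(model.named_parameters())
grads = torch.autograd.grad(inner_loss, list(params.values()), create_graph=True)

# Manually updated parameters: interim tensors still connected to b
updated = {name: p - 1.0 * g for (name, p), g in zip(params.items(), grads)}

# Run the model with the updated parameters without mutating its own weights
loss2 = functional_call(model, updated, (x,)).sum()
loss2.backward()
print(b.grad)  # non-None: loss2 now backpropagates to b
```

This keeps the update inside the autograd graph, at the cost of carrying a dictionary of updated parameters instead of letting the optimizer mutate the model in place.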