Is there a simple way to make the updated parameters interim tensors?

This is a problem related to higher-order gradients.
Consider this simple case:

import torch

a = torch.tensor([1., 2., 3.], requires_grad=True)  # model parameter
b = torch.tensor([1., 2., 3.], requires_grad=True)  # tensor I need to optimize
c = a * b
loss = c.sum()
loss.backward(create_graph=True)  # keep the graph so a.grad stays differentiable w.r.t. b
c = a - a.grad                    # "updated" version of a, still connected to b
loss2 = c.sum()

Here, b is the tensor I need to optimize, and a is a model parameter. In the code above, c is the updated version of tensor a, and loss2 is a loss that can backpropagate to tensor b. However, the computation graph from a to loss2 may be complex, since a can be a parameter of a huge model.
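
For reference, loss2 in this first version can indeed be differentiated with respect to b; a quick sanity check (continuing the snippet above) is:

# a.grad was produced with create_graph=True, so c = a - a.grad (here a - b)
# still depends on b, and loss2 can be differentiated with respect to b
grad_b, = torch.autograd.grad(loss2, b)
print(grad_b)  # tensor([-1., -1., -1.]), since loss2 = (a - b).sum()

The version below uses an optimizer instead: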

import torch
from torch.optim import SGD

a = torch.tensor([1., 2., 3.], requires_grad=True)
b = torch.tensor([1., 2., 3.], requires_grad=True)
opt = SGD([a], lr=1.)
c = a * b
loss = c.sum()
loss.backward(create_graph=True)
opt.step()       # updates a in place, outside the autograd graph
loss2 = a.sum()  # a is still a leaf tensor, so there is no path back to b

In this version, loss2 cannot backpropagate to tensor b, because opt.step() updates a in place without recording anything in the autograd graph. Is there a simple way to implement what I need, i.e., to make the updated parameter a an interim tensor induced by b, so that a can be used further to compute loss2?
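
To make the question concrete, the pattern I am after looks roughly like the sketch below, where the updated parameters are built with ordinary tensor operations and fed back through torch.func.functional_call (this assumes PyTorch >= 2.0; model, x, and the learning rate of 1.0 are only placeholders, and I am not sure this is the simplest way):

import torch
from torch.func import functional_call

# Placeholder model: its parameters play the role of a; b is the tensor to optimize.
model = torch.nn.Linear(3, 1)
b = torch.tensor([1., 2., 3.], requires_grad=True)
x = torch.randn(4, 3)

params = dict(model.named_parameters())

# Inner loss that couples the parameters with b.
loss = functional_call(model, params, (x * b,)).sum()

# Differentiable update: each updated parameter is an interim (non-leaf) tensor
# that still depends on b because of create_graph=True.
grads = torch.autograd.grad(loss, tuple(params.values()), create_graph=True)
updated = {name: p - 1.0 * g for (name, p), g in zip(params.items(), grads)}

# Outer loss computed with the updated parameters; it can backpropagate to b.
loss2 = functional_call(model, updated, (x,)).sum()
grad_b, = torch.autograd.grad(loss2, b)
print(grad_b)

The point is that updated is produced by ordinary tensor operations rather than opt.step(), so loss2 stays connected to b.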