# Performing backward on another backward

Hi folks! I ran into the following issue. I was wondering if someone could help me out.

I have a loss `l(w,theta)`, where `w` denotes weights and `theta` the network parameters. I first need to compute `theta.hat` with one SGD step:

```
theta.hat = theta - alpha * grad l(w, theta)
```
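Here is a minimal sketch of this update with plain tensors. It assumes toy scalar losses l(w, theta) = ½·w·theta² and L(theta.hat) = ½·theta.hat², which are made up for illustration, as are the numeric values. The point is that the step is an ordinary out-of-place expression and the inner gradient is taken with `create_graph=True`, so `theta_hat` stays connected to `w`:

```python
import torch

alpha = 0.1
w = torch.tensor(1.0, requires_grad=True)      # hypothetical weight
theta = torch.tensor(2.0, requires_grad=True)  # hypothetical "network parameter"

l = 0.5 * w * theta ** 2                       # toy inner loss l(w, theta)
# create_graph=True: the returned gradient (w * theta) is itself part of the graph
(g,) = torch.autograd.grad(l, theta, create_graph=True)
theta_hat = theta - alpha * g                  # out-of-place SGD step
L = 0.5 * theta_hat ** 2                       # toy outer loss L(theta_hat)

(dL_dw,) = torch.autograd.grad(L, w)           # flows back through g to w
print(dL_dw.item())
```

With these numbers the analytic result is -alpha·theta²·(1 - alpha·w) = -0.36, which the autograd value should match up to float32 rounding.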

Then, I need to use the updated `theta.hat` in the same network and calculate another loss `L(theta.hat)`. As `theta.hat` is a function of `w`, so is `L`. My question is, how can I get the gradient of `L` with respect to `w`?
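Written out with the chain rule, the gradient in question involves a mixed second derivative of `l`. Here θ stands for `theta`, θ̂ for `theta.hat`, and θ itself does not depend on w:

$$
\hat\theta(w) = \theta - \alpha\,\nabla_\theta l(w,\theta), \qquad
J_{ij} = \frac{\partial^2 l(w,\theta)}{\partial \theta_i\,\partial w_j}, \qquad
\nabla_w L\big(\hat\theta(w)\big) = -\alpha\, J^{\top}\, \nabla_{\hat\theta} L(\hat\theta).
$$

Since J differentiates the first gradient with respect to w, that first gradient must be recorded in the graph rather than treated as a constant. For instance, with scalar l = ½wθ² and L = ½θ̂², J = θ and ∇_θ̂ L = θ̂, so ∇_w L = -αθθ̂ = -αθ²(1 - αw).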

All my attempts failed with the following error: `One of the differentiated Tensors appears to not have been used in the graph`. So I guess `w` was not actually part of my graph.
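The error can be reproduced with a toy setup (the losses and values are hypothetical). If the inner gradient is taken without `create_graph=True`, it comes back as a plain constant, and `w` drops out of everything computed from it:

```python
import torch

alpha = 0.1
w = torch.tensor(1.0, requires_grad=True)
theta = torch.tensor(2.0, requires_grad=True)

l = 0.5 * w * theta ** 2
# No create_graph=True: g carries no history, so it does not depend on w anymore
(g,) = torch.autograd.grad(l, theta)
theta_hat = theta - alpha * g   # still depends on theta, but not on w
L = 0.5 * theta_hat ** 2

try:
    torch.autograd.grad(L, w)
    raised = False
except RuntimeError as err:
    # "One of the differentiated Tensors appears to not have been used in the graph..."
    raised = "not have been used in the graph" in str(err)
print(raised)  # True
```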

For example, I tried using `optimizer.step()`. Afterwards the parameters held the updated values `theta.hat`, but `w` was lost from the graph. In other words, `optimizer.step()` behaved as if it ran under `with torch.no_grad()`, whereas I need something that keeps the graph tracing back to `w`.
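A small check of that behaviour, using a hypothetical `nn.Linear` model and per-example weights `eps`. Even when the gradients are built with `create_graph=True`, `opt.step()` writes the new values into the parameters in place without recording the update, so the parameters remain history-free leaves:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
net = nn.Linear(2, 1)
eps = torch.ones(4, requires_grad=True)        # hypothetical per-example weights
x = torch.randn(4, 2)

loss = (net(x).squeeze(1) * eps).sum()
loss.backward(create_graph=True)               # .grad now depends on eps

opt = optim.SGD(net.parameters(), lr=0.1)      # lr chosen arbitrarily
opt.step()                                     # in-place update, not recorded

print(net.weight.grad.grad_fn is not None)     # True: the gradient kept its history
print(net.weight.is_leaf, net.weight.grad_fn)  # True None: the update did not
```

This matches how `torch.optim` optimizers work by default: `step()` runs with gradient tracking disabled.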

Thanks for your help!

Hi,

To do this, you need to pass `create_graph=True` to the first call to `backward()`!
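Here is what `create_graph=True` does, shown on a one-variable example (y = x³ is only an illustration). The gradient stored in `.grad` is itself recorded in the graph, so it can be differentiated again:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3
y.backward(create_graph=True)              # x.grad = 3x^2 = 12, and it has a grad_fn
(d2y_dx2,) = torch.autograd.grad(x.grad, x)  # d(3x^2)/dx = 6x = 12
print(x.grad.item(), d2y_dx2.item())       # 12.0 12.0
```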

Thanks for your suggestion. I tried and the same thing happened. Could you take a look at the following example? The goal is to get the gradient of `l_val` with respect to `eps`.

```python
import torch
import torch.nn as nn
import torch.optim as optim

def cal_loss(net, x_train, y_train, x_val, y_val):
    y_train_hat = net(x_train)
    cost = nn.CrossEntropyLoss(reduction='none')(y_train_hat, y_train)

    # here is eps, not sure whether to set it to nn.Parameter, but either way didn't work
    eps = nn.Parameter(torch.zeros_like(cost), requires_grad=True)
    l_train = torch.sum(cost * eps)

    # first backward step
    opt = optim.SGD(net.parameters(), lr=0.1)  # lr not given in the post; 0.1 is a placeholder
    l_train.backward(create_graph=True)
    opt.step()

    # use updated network to calculate l_val
    y_val_hat = net(x_val)
    l_val = nn.CrossEntropyLoss()(y_val_hat, y_val)

    # this didn't work (call reconstructed from the error message below)
    grad_eps = torch.autograd.grad(l_val, eps)
```
The last line raised this error: `One of the differentiated Tensors appears to not have been used in the graph`. I suspected `opt.step()` was the culprit, but I'm not sure exactly which line left `eps` out of the graph. Any ideas how to fix this? Thanks so much!
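One way to keep `eps` in the graph, sketched here under the assumption of PyTorch 2.0 or later (for `torch.func.functional_call`), is to skip the optimizer for this step. You take the parameter gradients with `torch.autograd.grad(..., create_graph=True)`, build `theta_hat` out of place, and run the same network with those tensors. The function name `cal_eps_grad` and the default `alpha=0.1` are illustrative choices, not from the original post:

```python
import torch
import torch.nn as nn
from torch.func import functional_call

def cal_eps_grad(net, x_train, y_train, x_val, y_val, alpha=0.1):
    # Per-example training losses, weighted by eps (one weight per example)
    cost = nn.CrossEntropyLoss(reduction='none')(net(x_train), y_train)
    eps = torch.zeros_like(cost, requires_grad=True)
    l_train = torch.sum(cost * eps)

    # Parameter gradients; create_graph=True keeps their dependence on eps
    names, params = zip(*net.named_parameters())
    grads = torch.autograd.grad(l_train, params, create_graph=True)

    # Out-of-place SGD step: theta_hat = theta - alpha * grad (net itself is untouched)
    theta_hat = {n: p - alpha * g for n, p, g in zip(names, params, grads)}

    # Evaluate the same architecture with theta_hat instead of its stored parameters
    y_val_hat = functional_call(net, theta_hat, (x_val,))
    l_val = nn.CrossEntropyLoss()(y_val_hat, y_val)

    # l_val now depends on eps through theta_hat
    (grad_eps,) = torch.autograd.grad(l_val, eps)
    return grad_eps
```

Because the update lives in a separate dictionary, the network's own parameters are not modified, so you can still take an ordinary `opt.step()` afterwards if the training loop needs one. Even though `eps` starts at zero, its gradient is generally nonzero: each entry measures how up-weighting one training example would change the validation loss after the step.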