Performing backward on another backward

Hi folks! I ran into the following issue. I was wondering if someone could help me out.

I have a loss l(w, theta), where w is a set of weights and theta is the network parameters. First, I need to compute theta_hat with one SGD step:

theta_hat = theta - alpha * grad_theta l(w, theta)

Then I need to use the updated theta_hat in the same network and compute another loss L(theta_hat). Since theta_hat is a function of w, so is L. My question is: how can I get the gradient of L with respect to w?

All my attempts failed with the following error: "One of the differentiated Tensors appears to not have been used in the graph." So I guess w never actually made it into my graph.

For example, I tried using optimizer.step(): the parameters did get the updated values theta_hat, but the connection back to w was lost. In other words, optimizer.step() behaves as if it runs under torch.no_grad(), whereas I need something that keeps the graph tracing back to w.
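To make this concrete, here is a toy sketch of what I mean (the names p and w are made up just for illustration):

import torch
import torch.nn as nn
import torch.optim as optim

p = nn.Parameter(torch.tensor(1.0))        # stands in for theta
w = torch.tensor(2.0, requires_grad=True)  # stands in for my weights w

opt = optim.SGD([p], lr=0.1)
loss = w * p                               # the loss depends on both w and p
loss.backward()
opt.step()                                 # p gets the new value, but in place and without grad tracking

print(p.grad_fn)                           # None: p is still a leaf, nothing traces back to w
# so differentiating any loss built from the updated p with respect to w fails with
# "One of the differentiated Tensors appears to not have been used in the graph"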

Thanks for your help!

Hi,

To do this, you will need to set create_graph=True in the first call to backward!
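For example, with a scalar x (just a minimal sketch):

import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3
y.backward(create_graph=True)               # also records a graph for the backward pass itself

# x.grad is 3 * x**2 = 12 and is itself differentiable
second = torch.autograd.grad(x.grad, x)[0]  # d(3x^2)/dx = 6x = 12
print(second)

Without create_graph=True, x.grad would be a plain tensor with no history, so the second torch.autograd.grad call would fail.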

Thanks for your suggestion. I tried it, but the same thing happened. Could you take a look at the following example? The goal is to get the gradient of l_val with respect to eps.

import torch
import torch.nn as nn
import torch.optim as optim

def cal_loss(net, x_train, y_train, x_val, y_val):

    y_train_hat = net(x_train)
    cost = nn.CrossEntropyLoss(reduction='none')(y_train_hat, y_train)

    # here is eps; not sure whether it needs to be an nn.Parameter, but neither way worked
    eps = nn.Parameter(torch.zeros_like(cost), requires_grad=True)
    l_train = torch.sum(cost * eps)

    # first backward step
    opt = optim.SGD(net.parameters(), lr=0.1)  # lr added so SGD can be constructed
    opt.zero_grad()
    l_train.backward(create_graph=True)
    opt.step()

    # use the updated network to compute l_val
    y_val_hat = net(x_val)
    l_val = nn.CrossEntropyLoss()(y_val_hat, y_val)

    # this didn't work
    grad_eps = torch.autograd.grad(l_val, eps)[0]

The last line raised the same error: "One of the differentiated Tensors appears to not have been used in the graph." I suspect opt.step() is the culprit, but I'm not sure exactly which line leaves eps out of the graph. Any ideas on how to fix this? Thanks so much!

Jason

Hi,

Yes, one issue is that you are trying to differentiate through the optimizer step, but PyTorch optimizers are not differentiable: opt.step() overwrites the parameters in place without recording anything in the autograd graph, so eps is no longer connected to l_val.
You can take a look at libraries like higher to see how to work around this.
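Roughly, the higher version of your example would look like this (a sketch from memory of higher's README, so please double-check the details against its docs):

import torch
import torch.nn as nn
import higher

# net, opt, x_train, y_train, x_val, y_val as in your cal_loss
with higher.innerloop_ctx(net, opt) as (fnet, diffopt):
    cost = nn.CrossEntropyLoss(reduction='none')(fnet(x_train), y_train)
    eps = torch.zeros_like(cost, requires_grad=True)
    diffopt.step(torch.sum(cost * eps))        # differentiable inner SGD step
    l_val = nn.CrossEntropyLoss()(fnet(x_val), y_val)
    grad_eps = torch.autograd.grad(l_val, eps)[0]

The key difference is that diffopt.step() applies the same SGD update out of place, on a functional copy of the network, so the updated parameters stay connected to eps in the graph instead of being overwritten under no_grad.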

Thanks, I’ll look into that :slight_smile: