How to calculate 2nd derivative of a likelihood function

Hi,

I’m afraid there isn’t. This is a limitation of automatic differentiation. You can find a more generic implementation in this gist, but as you can see, you still need the for-loop.

1 Like

Got it. Thanks a lot for your help.

@albanD I wanted to calculate the second derivative of the loss with respect to each hidden state in an LSTM, but could not find a way to do that. Is there a way to have grad enabled for the registered hooks?

More specifically, I want to do something of the following form. I have already modified the RNN to store the individual hidden states for every time step. However, after calling the standard torch.autograd.grad function, grad.requires_grad is False:

grads_h_i = []
for i in range(len(model.rnn.f_h_list)):
    grad = torch.autograd.grad(loss1, model.rnn.f_h_list[i], retain_graph=True)[0]
    print(grad.requires_grad)
    grads_h_i.append(torch.norm(grad))
grads_tensor = torch.tensor(grads_h_i, requires_grad=True, device=grads_h_i[0].device)
loss2 = torch.norm(grads_tensor)
loss2.backward()

You have to use the create_graph=True flag when you call autograd.grad.
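
For example, a minimal sketch of the corrected loop (assuming the same model.rnn.f_h_list attribute and loss1 from the question; torch.stack is used instead of torch.tensor so the gradient norms keep their graph history):

grads_h_i = []
for h in model.rnn.f_h_list:
    # create_graph=True builds a graph of the backward pass itself,
    # so grad.requires_grad will be True
    grad = torch.autograd.grad(loss1, h, retain_graph=True, create_graph=True)[0]
    grads_h_i.append(torch.norm(grad))
# stack the norms (rather than rebuilding a leaf with torch.tensor) to stay connected to the graph
loss2 = torch.norm(torch.stack(grads_h_i))
loss2.backward()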

2 Likes

Thank you! This worked :slight_smile:
I was mixing up create_graph and retain_graph.

@Wei_Deng

Did you try using PyTorch’s higher library?

https://higher.readthedocs.io/en/latest/toplevel.html?highlight=innerloop_ctx#higher.innerloop_ctx
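
For reference, an inner loop with innerloop_ctx typically looks roughly like this (a sketch; loss_fn, the data tensors, and n_inner_steps are hypothetical placeholders):

import higher
import torch

inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)
with higher.innerloop_ctx(model, inner_opt) as (fmodel, diffopt):
    for _ in range(n_inner_steps):
        inner_loss = loss_fn(fmodel(x_inner), y_inner)
        diffopt.step(inner_loss)  # differentiable parameter update on fmodel
    outer_loss = loss_fn(fmodel(x_outer), y_outer)
    outer_loss.backward()  # second-order gradients flow back to model.parameters()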

Hi, I am a little bit confused by this code. For the second autograd.grad line, will the gradient also flow backward through new_t, since new_t is itself a function of theta? It seems to me that there are two gradient flows into theta.
Will this lead to a second-order derivative, first for t, then for theta?

Hi Alban @albanD ,

I have a very similar question. In the subsection of the code below, if new_out = f(new_t) and everything else remains the same, can I still compute theta’s gradients? Thank you!

which means:

new_out = f(new_t)
new_loss = F.mse_loss(new_out, target)
gradtheta = autograd.grad(new_loss, theta)  # how can I get this when theta is only involved in the 1st half? Must I rewrite gradt or new_t so that it contains theta?
new_theta = theta - gradtheta

Yes, it will flow through new_t, as that is the only path to reach theta and t.
You can use pytorchviz (https://github.com/szagoruyko/pytorchviz), a small package to create visualizations of PyTorch execution graphs, to get a visualization of this graph.

You need to specify create_graph=True for the first call to autograd.grad to be able to backprop through a backward pass.
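
As a minimal illustration of backpropagating through a backward pass (a toy scalar example):

import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3
# first call: create_graph=True records the backward pass itself
g, = torch.autograd.grad(y, x, create_graph=True)  # g = 3 * x**2 = 12
# second call: differentiate the gradient to get the second derivative
g2, = torch.autograd.grad(g, x)  # g2 = 6 * x = 12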

1 Like

Thank you very much, @albanD! I saw that create_graph is already set to True in the 1st chunk of the code:

gradt = autograd.grad(loss, t, retain_graph=True, create_graph=True)  ### create_graph=True for the first call to autograd.grad

Do you mean the above?

However, in my own implementation I tried to obtain theta’s gradients but failed. When I print out theta’s gradients, I get theta_grads (None,). The error thrown is:

Traceback (most recent call last):
theta -= learning_rate * theta_grads
TypeError: can’t multiply sequence by non-int of type ‘float’

Below is my code, very similar to the above. Here, f(t) is net (resnet34 in my case), theta is a 10 by 10 matrix with requires_grad=True, and data_s and data_g can be substituted with CIFAR-10 or any image dataset. I checked that both out (the output computed from net and theta) and theta are non-empty and contain no zero values.

I could not figure out why. I wonder if you could have a look and see if there are any possible mistakes? Thank you again!

original_weights = OrderedDict()
for name, param in net.named_parameters():
    if not param.requires_grad:
        print(name)   
    else:
        original_weights[name] = copy.deepcopy(param)
original_weights_keys = tuple(original_weights.keys())

with torch.enable_grad():
    logits = net(data_s)
    out = torch.matmul(F.softmax(logits, dim=1), theta)
    loss_ = F.cross_entropy(out, target_s, reduction='sum')
    print('loss_', loss_)  ### prints: tensor(294.9288, device='cuda:0', grad_fn=<NllLossBackward>)
        
    #manually compute gradients for net and update net's weights
    net_grads = torch.autograd.grad(loss_, net.parameters(), retain_graph=True, create_graph=True, allow_unused=True) 
    for param, grad in zip(original_weights_keys, net_grads):                
        if grad is None: 
            print(param)
            continue
        else:
            net.state_dict()[param] -= learning_rate * grad

    new_out = net(data_g)
    new_loss = F.cross_entropy(new_out, target_g, reduction='sum')
    theta_grads = torch.autograd.grad(new_loss, theta, allow_unused=True)
    print('### theta_grads', theta_grads) ###print: theta_grads (None,)

    theta -= learning_rate * theta_grads
    theta.clamp(min=0, max=1.0)
    theta.grad.zero_()

Hi,

I think the problem is net.state_dict()[param] -= learning_rate * grad.
In particular, when you call state_dict(), it returns Tensors that are detached from the actual Tensors in the model. So the update you’re doing here is “hidden” from autograd.
And so there is no path between new_loss and theta.
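
A quick way to see this detachment (a minimal sketch with a hypothetical small layer):

import torch

lin = torch.nn.Linear(2, 2)
print(lin.weight.requires_grad)                  # True: the live Parameter
print(lin.state_dict()["weight"].requires_grad)  # False: a detached Tensor,
# so in-place updates on it are invisible to autograd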

If you want parameters in your net that “have history”, then they cannot be nn.Parameters, and you need to fix that
using the utility functions from here: pytorch/utils.py at master · pytorch/pytorch · GitHub:

    for param, grad in zip(original_weights_keys, net_grads):
        if grad is None:
            print(param)
            continue
        else:
            new_weight = original_weights[param] - learning_rate * grad
            # remove the old Parameter
            _del_nested_attr(net, param.split("."))
            # set the Tensor with history
            _set_nested_attr(net, param.split("."), new_weight)

Note though that after this, net.parameters() will be empty because you removed all the Parameters.
You can add them back later if you want by using _set_nested_attr(net, param.split("."), nn.Parameter(old_weight)).
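
For reference, a simplified sketch of what those nested-attribute helpers do (the linked file contains the actual implementations):

def _del_nested_attr(obj, names):
    # delete e.g. net.layer1.weight when names == ["layer1", "weight"]
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        _del_nested_attr(getattr(obj, names[0]), names[1:])

def _set_nested_attr(obj, names, value):
    # set e.g. net.layer1.weight when names == ["layer1", "weight"]
    if len(names) == 1:
        setattr(obj, names[0], value)
    else:
        _set_nested_attr(getattr(obj, names[0]), names[1:], value)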

Awesome, @albanD! Now I can compute theta’s gradients :smiley:!

For the below, may I ask if I can use model = copy.deepcopy(net) at the beginning of the code, and then net = copy.deepcopy(model) later when I need to restore the model? Would there be a difference?

And in the code below, can the gradients be mapped to their corresponding weight matrices (is the order of the gradients the same as the one in net.named_parameters())? Or is this actually not guaranteed?

Thanks!

Yes, you can do that to save the model if you want. The difference is that you would save the whole net, while my proposal just saves the params on the side.

And in the code below, can the gradients be mapped to their corresponding weight matrices (is the order of the gradients the same as the one in net.named_parameters())? Or is this actually not guaranteed?

The ordering is guaranteed.
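
For example, it is safe to pair the returned gradients with the parameter names like this (a minimal sketch reusing the net and loss_ from above):

net_grads = torch.autograd.grad(loss_, net.parameters(), retain_graph=True, create_graph=True, allow_unused=True)
for (name, _), grad in zip(net.named_parameters(), net_grads):
    # the order of grads matches the order of net.named_parameters()
    print(name, None if grad is None else grad.shape)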

1 Like

Thank you Alban! :+1:t2: :+1:t2: :+1:t2:

Hello Alban,

In this case, if theta is a function of “loss”, will gradtheta contain gradient information of the gradient of theta w.r.t. loss?

Thanks!

if theta is a function of “loss”,

theta is used to compute the loss. So that’s not possible in this example :smiley:

Do you have a code sample that shows what you want to do?

1 Like

Sorry, I meant: if loss is a function of theta, which is exactly the case in this example (theta is used to compute loss). So in this case, will gradtheta contain the gradient of theta w.r.t. loss? If the answer is yes, is there any way to avoid that? Thanks a lot!!

When you use autograd.grad(out, inp), it gives you dout/dinp.
The fact that an intermediate variable is called loss is not really relevant.

Keep in mind that autograd.grad is very different from .backward(): it just returns the value, it does not accumulate into the .grad field.
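
A minimal illustration of that difference (a toy scalar example):

import torch

x = torch.tensor(3.0, requires_grad=True)
y = x ** 2
g, = torch.autograd.grad(y, x)  # g = tensor(6.) is returned directly
print(x.grad)                   # None: autograd.grad did not touch .grad

y = x ** 2
y.backward()
print(x.grad)                   # tensor(6.): backward accumulated into .grad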