Gradient of loss (that depends on gradient of network) with respect to parameters

I’m trying to compute the gradient of my loss function with respect to my model parameters in PyTorch.

That is, let u(x; θ) be the model, where x is the input (in R^n) and θ are the model parameters, and let L be a loss computed from u. I’m trying to compute dL/dθ.

For a “simple” loss function, this is not a problem, but my loss function depends on the gradient of the model with respect to its inputs (i.e., du/dx). When I attempt this, I get the following error message:

One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

Here is a minimal example to illustrate the issue:

import torch
import torch.nn as nn
from torch.autograd import grad

model = nn.Sequential(nn.Linear(1, 10), nn.Tanh(), nn.Linear(10, 1))

def loss1(x, u):
    return torch.mean(u)

def loss2(x, u):
    d_u_x = grad(u, x, torch.ones_like(u), retain_graph=True, create_graph=True)[0]
    return torch.mean(d_u_x)

x = torch.randn(10, 1)
x.requires_grad_()
u = model(x)

loss = loss2(x, u)
d_loss_params = grad(loss, model.parameters(), retain_graph=True)

If I change the second to last line to read loss = loss1(x, u) things work as expected.

Update: it appears to work if I set bias=False for the nn.Linear layers. OK, that makes some sense since the bias is not trainable. But that raises the question: how do I extract only the trainable parameters to use in the gradient computation?

Hi bwehlin!

I don’t know the precise internals, but roughly speaking, this is the basic behavior:

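When a tensor contributes only a constant to some quantity, the derivative of that
quantity does not depend on the tensor, so autograd does not include the tensor in the
graph built for the derivative. This is why you get a tensor that “appears to not have
been used in the graph.”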

The second Linear in your Sequential appears at the end. Therefore that Linear’s
bias appears only as an additive constant in the output of u = model(x). u has a
non-zero gradient with respect to the second bias, but that gradient is a constant,
so autograd doesn’t include the bias in the graph created for d_u_x.
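
To make this concrete, here is a sketch of the derivative for the model in the question. The symbols W_1, b_1, W_2, b_2 are labels introduced here for the weights and biases of the first and second Linear; they are not names from the original code.

```latex
u(x) = W_2 \tanh(W_1 x + b_1) + b_2
\qquad\Longrightarrow\qquad
\frac{\partial u}{\partial x}
  = W_2 \,\mathrm{diag}\!\left(1 - \tanh^2(W_1 x + b_1)\right) W_1
```

The second bias b_2 drops out of du/dx, so anything built only from d_u_x (such as loss2) has no dependence on it.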

For example, if you change:

model = nn.Sequential(nn.Linear(1, 10), nn.Tanh(), nn.Linear(10, 1))

to:

model = nn.Sequential(nn.Linear(1, 10), nn.Tanh(), nn.Linear(10, 1), nn.Tanh())

then the output of model will include the second bias passed through the non-linear
function Tanh(), so the gradient with respect to the second bias won’t be a constant
and won’t get removed from the computation graph. This changes model, of course,
but the specific error will go away.
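
As a quick check of this, here is a minimal sketch that reuses the setup from the question with the extra Tanh() appended. The variable name params is only for illustration.

```python
import torch
import torch.nn as nn
from torch.autograd import grad

# Same model as in the question, with a Tanh appended after the second Linear.
model = nn.Sequential(nn.Linear(1, 10), nn.Tanh(), nn.Linear(10, 1), nn.Tanh())

x = torch.randn(10, 1, requires_grad=True)
u = model(x)

# du/dx now depends on the second bias through the derivative of the final Tanh.
d_u_x = grad(u, x, torch.ones_like(u), create_graph=True)[0]
loss = torch.mean(d_u_x)

# No allow_unused here: every parameter should appear in the graph of loss.
params = list(model.parameters())
d_loss_params = grad(loss, params)
print([tuple(g.shape) for g in d_loss_params])
```

If the call to grad() returns without raising, every parameter, including the second bias, was found in the graph.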

You could instead change:

d_loss_params = grad(loss, model.parameters(), retain_graph=True)

to:

d_loss_params = grad(loss, model.parameters(), retain_graph=True, allow_unused = True)

autograd still removes the constant branch from the computation graph, but now you’re
telling the call to grad() that you expect to be trying to differentiate an “unused” tensor
and that it’s not an error on your part. So autograd proceeds without raising the error.
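
To see which tensor autograd considers unused, here is a small sketch based on the example from the question. The names list and the unused list are illustrative helpers, not part of the original code.

```python
import torch
import torch.nn as nn
from torch.autograd import grad

model = nn.Sequential(nn.Linear(1, 10), nn.Tanh(), nn.Linear(10, 1))

x = torch.randn(10, 1, requires_grad=True)
u = model(x)
d_u_x = grad(u, x, torch.ones_like(u), create_graph=True)[0]
loss = torch.mean(d_u_x)

# nn.Sequential names parameters by layer index, e.g. "0.weight", "2.bias".
names = [name for name, _ in model.named_parameters()]
d_loss_params = grad(loss, list(model.parameters()), allow_unused=True)

# Names of the parameters for which grad() returned None instead of a tensor.
unused = [name for name, g in zip(names, d_loss_params) if g is None]
print(unused)
```

I would expect this to list only the bias of the second Linear (named 2.bias here); the other three parameters get ordinary gradient tensors.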

Yes, setting bias = False makes the error go away because, as noted above, the bias of
the second Linear only contributes a constant to the output of model. (You should be
able to leave bias = True for the first Linear, only setting bias = False for the
second Linear, and still get rid of the error.) Of course, setting bias = False does
change the model.

Just to be clear, the bias of a Linear is, in general, trainable. For example, in your
case, the bias of the first Linear is trainable and would be changed by optimizer
step()s.

If by “not trainable” you mean that the second derivative of the output of model, taken
once with respect to x and once with respect to the second bias, is zero, then yes, you
could phrase things that way.

As a practical matter, rather than trying to arrange things so that the second bias still
contributes its non-trivial constant value to the output of model but does not participate
in any computation graph, just go ahead and use allow_unused = True. autograd will then do
the correct thing for your use case: leaving the derivative for the constant tensor out of
the backpropagation and computing an explicit zero for that derivative both give the same
final result. (It’s just that leaving it out is a little cheaper.)

Best.

K. Frank

Hi KFrank,

Thank you for your answer.

The first thing I tried was actually allow_unused=True, but the problem was that I then passed the result of this on to nn.utils.parameters_to_vector which gives AttributeError: 'NoneType' object has no attribute 'device'.

I had somehow assumed allow_unused=True made all my gradients None, but on closer inspection only certain gradients are None.

I’m trying to implement Inverse Dirichlet Weighting [1] and this requires computing the standard deviation of all gradients wrt the model’s parameters. I took another peek at the documentation of autograd.grad just now, and there is an additional option materialize_grads that when turned on sets None gradients to zero in the result. This gives exactly what I want!

So the solution is:

d_loss_params = grad(loss, model.parameters(), retain_graph=True, allow_unused=True, materialize_grads=True)
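
If the installed PyTorch version does not accept materialize_grads, the same effect can be had by hand. The sketch below replaces each None with zeros shaped like the corresponding parameter and flattens everything into one vector. The names raw, filled, flat and grad_std are illustrative, and the plain standard deviation at the end is only a stand-in for a gradient statistic, not the weighting formula from [1].

```python
import torch
import torch.nn as nn
from torch.autograd import grad

model = nn.Sequential(nn.Linear(1, 10), nn.Tanh(), nn.Linear(10, 1))
params = list(model.parameters())

x = torch.randn(10, 1, requires_grad=True)
u = model(x)
d_u_x = grad(u, x, torch.ones_like(u), create_graph=True)[0]
loss = torch.mean(d_u_x)

# Unused parameters come back as None rather than as zero tensors.
raw = grad(loss, params, retain_graph=True, allow_unused=True)

# Replace each None with zeros of that parameter's shape.
filled = [torch.zeros_like(p) if g is None else g for p, g in zip(params, raw)]

# Flatten into a single vector (reshape also copes with non-contiguous gradients).
flat = torch.cat([g.reshape(-1) for g in filled])

# Example statistic over all gradient entries.
grad_std = flat.std()
print(flat.shape, grad_std)
```

For this model the flattened vector has 31 entries (10 + 10 + 10 + 1), with the last one, belonging to the second bias, equal to zero.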

Thanks

References:
[1] Maddu, S., Sturm, D., Müller, C. L., & Sbalzarini, I. F. (2022). Inverse Dirichlet weighting enables reliable training of physics informed neural networks. Machine Learning: Science and Technology, 3(1), 015026.