Need help computing gradient of the output with respect to the input

Hi! I have a trained model and now I would like to compute the gradient of the output with respect to the inputs. Based on the Stack Overflow post "python - Getting the output's grad with respect to the input", I am doing it like this:

inputs.requires_grad_()

outputs=model(preprocessing(inputs))

gradients= [torch.autograd.grad(inputs=inputs, outputs=outputs[i], allow_unused=True, retain_graph=True) for i in range(len(outputs))]

But it is extremely slow, and I suspect I am doing something wrong for it to be that slow.
Is there something I am missing?
Thanks in advance
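For anyone who wants to reproduce this, here is a minimal runnable sketch of the loop above. The model, preprocessing() and the shapes are hypothetical stand-ins (a small MLP mapping 3 input features to 1 scalar output); only the loop structure comes from the post.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for the poster's model and preprocessing step.
model = torch.nn.Sequential(
    torch.nn.Linear(3, 16),
    torch.nn.Tanh(),
    torch.nn.Linear(16, 1),
)

def preprocessing(x):
    return 2.0 * x  # hypothetical placeholder transform

inputs = torch.randn(5, 3)
inputs.requires_grad_()  # note: requires_grad_, with an "s"

outputs = model(preprocessing(inputs))  # shape [5, 1]

# One autograd.grad call (one backward pass) per output element.
gradients = [
    torch.autograd.grad(outputs=outputs[i], inputs=inputs, retain_graph=True)[0]
    for i in range(len(outputs))
]

print(len(gradients), gradients[0].shape)  # 5 torch.Size([5, 3])
```

Each entry of gradients has the full shape of inputs, which is why the number of calls (and the run time) grows with the batch size.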

I explained that poorly. The slow part is not the gradient computation itself: each call only gives me the gradient of one output, so I have to call the function once per output, and that is what makes it slow.
In the end, I saw that I could simply pass a grad_outputs tensor, as follows:

torch.autograd.grad(outputs=output,inputs=input,grad_outputs=torch.ones_like(output),retain_graph=True)

I think this computes the gradient of each output with respect to its respective input. Am I right?

Hi Malfonsoarquimea!

No, by using grad_outputs = torch.ones_like (output) you end up
computing the inner product (dot product) of the gradients of output with
the torch.ones_like() vector, that is, the sum of the per-element gradients.
This does not preserve the gradient of each output as a separate entity.
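To make the "inner product" point concrete, here is a small sketch using a hypothetical function whose outputs deliberately mix batch elements (every output depends on every input row through a mean over the batch). Passing torch.ones_like() as grad_outputs returns exactly the sum of the separately computed per-output gradients, a vector-Jacobian product with a vector of ones, not the individual gradients.

```python
import torch

torch.manual_seed(0)

x = torch.randn(4, 3, requires_grad=True)

def f(x):
    # Hypothetical function: each output depends on ALL rows of x via the mean.
    return (x * x.mean(dim=0)).sum(dim=1)  # shape [4]

y = f(x)

# Per-output gradients: one backward pass each, stacked to shape [4, 4, 3].
per_output = torch.stack([
    torch.autograd.grad(y[i], x, retain_graph=True)[0] for i in range(len(y))
])

# One backward pass with a vector of ones as grad_outputs.
(vjp,) = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y))

print(torch.allclose(vjp, per_output.sum(dim=0)))  # True
```

The four [4, 3] gradients are collapsed into one [4, 3] tensor, so the individual gradients cannot be recovered from vjp.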

At issue is the fact that, as currently implemented, autograd computes the
gradient of a single scalar result (in any individual backward pass). To get
the gradients of the elements of output separately, you have to run multiple
backward passes in a loop, such as you did in your original post.

And running such a loop over the elements of output does take time: each
iteration runs its own backward pass through the graph.
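If every per-output gradient really is needed separately (the full Jacobian), PyTorch also offers torch.autograd.functional.jacobian, which is not used elsewhere in this thread. The sketch below compares it with the explicit loop on a hypothetical small model. With vectorize=True (documented as experimental) it batches the backward passes instead of looping in Python, which can be faster, but it still builds the full Jacobian, whose size is (number of outputs) × (number of input elements).

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)

# Hypothetical small model: 3 input features -> 1 scalar per batch element.
model = torch.nn.Sequential(torch.nn.Linear(3, 8), torch.nn.Tanh(), torch.nn.Linear(8, 1))
x = torch.randn(4, 3)

def f(inp):
    return model(inp).squeeze(-1)  # shape [4]

# Full Jacobian: jac[i, j, k] = d f(x)[i] / d x[j, k], shape [4, 4, 3].
jac = jacobian(f, x, vectorize=True)

# The same thing with an explicit per-output loop, for comparison.
xg = x.clone().requires_grad_()
y = f(xg)
loop = torch.stack([torch.autograd.grad(y[i], xg, retain_graph=True)[0] for i in range(4)])

print(jac.shape, torch.allclose(jac, loop))  # torch.Size([4, 4, 3]) True
```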

This example illustrates these points:

>>> torch.__version__
'1.10.2'
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> lin = torch.nn.Linear (1, 1, bias = False)   # a trivial "model"
>>> lin.weight
Parameter containing:
tensor([[-0.2084]], requires_grad=True)
>>> t = (torch.arange (4.) + 1).unsqueeze (-1)
>>> t
tensor([[1.],
        [2.],
        [3.],
        [4.]])
>>>
>>> lin (t)   # output consists of four separate scalars
tensor([[-0.2084],
        [-0.4168],
        [-0.6252],
        [-0.8336]], grad_fn=<MmBackward0>)
>>>
>>> # use loop to compute gradient with respect to each scalar output
... # (for simplicity, this example unnecessarily runs four forward passes)
... for  i in range (4):
...     lin.weight.grad = None
...     lin (t)[i].backward()
...     print ('i:', i, ' gradient:', lin.weight.grad)
...
i: 0  gradient: tensor([[1.]])
i: 1  gradient: tensor([[2.]])
i: 2  gradient: tensor([[3.]])
i: 3  gradient: tensor([[4.]])
>>> # gradient of a single scalar result
... lin.weight.grad = None
>>> lin (t).sum().backward()   # sum over outputs to get a single scalar result
>>> lin.weight.grad
tensor([[10.]])
>>>
>>> # using gradient = torch.ones() still just gives a single gradient
... lin.weight.grad = None
>>> lin (t).backward (gradient = torch.ones (4, 1))
>>> lin.weight.grad
tensor([[10.]])

(This example is packaged to illustrate computing gradients with respect to
the weights of a “layer”, rather than with respect to “inputs,” as this is the
more common use case, but there is no conceptual difference.)
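Since there is no conceptual difference, here is the same trivial model with the gradient taken with respect to the input t instead of the weight. This is a sketch: unlike the session above, t is made to require grad here. Each per-output gradient is non-zero only in its own row, while the ones vector fills every row at once.

```python
import torch

_ = torch.manual_seed(2022)

lin = torch.nn.Linear(1, 1, bias=False)  # the same trivial "model"
t = (torch.arange(4.) + 1).unsqueeze(-1).requires_grad_()

out = lin(t)  # shape [4, 1]

for i in range(4):
    (g,) = torch.autograd.grad(out[i], t, retain_graph=True)
    # Row i holds lin.weight; every other row is zero.
    print('i:', i, ' gradient:', g.squeeze(-1))

# A single backward pass with a vector of ones as grad_outputs.
(g_all,) = torch.autograd.grad(out, t, grad_outputs=torch.ones_like(out))
print(g_all.squeeze(-1))  # every entry equals lin.weight
```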

Best.

K. Frank

Hi! thanks for your answer which is very well packaged!
I am probably wrong, but as far as I have tested, the code in my original post gives a tensor the size of the input in which every row is zero except the one corresponding to the output I pass in:

>>>inputs
tensor([[-1.0971,  1.1935,  1.2421],
        [-1.0270,  1.1890,  1.2345],
        [-0.9428,  1.1836,  1.2254],
        ...,
        [ 1.6314, -2.7406,  0.6462],
        [ 1.6619, -2.7843,  0.6395],
        [ 1.7008, -2.8399,  0.6311]], device='cuda:0', grad_fn=<ViewBackward0>)
>>>outputs
tensor([[-10.9158],
        [ -2.2752],
        [ -5.6269],
        ...,
        [  2.9010],
        [ 22.6586],
        [ 25.9725]], device='cuda:0', grad_fn=<AddmmBackward0>)
>>>gradients=torch.autograd.grad(inputs=inputs,outputs=outputs[0],allow_unused=True,retain_graph=True)
>>>gradients
(tensor([[  534.2734,...='cuda:0'),)
>>>gradients[0]
tensor([[  534.2734, -2844.8074,   709.5488],
        [    0.0000,     0.0000,     0.0000],
        [    0.0000,     0.0000,     0.0000],
        ...,
        [    0.0000,     0.0000,     0.0000],
        [    0.0000,     0.0000,     0.0000],
        [    0.0000,     0.0000,     0.0000]], device='cuda:0')

However, when I pass torch.ones_like() as grad_outputs, I get a single tensor the size of the input. I assumed it contained the gradient of each output with respect to its respective input, but maybe I am wrong:

>>>gradients=torch.autograd.grad(outputs=outputs,inputs=inputs,grad_outputs=torch.ones_like(outputs),retain_graph=True)
>>>gradients
(tensor([[  534.2734,...='cuda:0'),)
>>>gradients[0]
tensor([[  534.2734, -2844.8074,   709.5488],
        [ -247.5381,    85.9924,   537.6161],
        [  533.3944,   128.2310,    29.9466],
        ...,
        [ -621.1329,  -340.0153,   701.6570],
        [  620.2058,  1783.1162,  -927.3855],
        [  387.5083,   450.1457,   -69.2568]], device='cuda:0')

Here I don’t see the gradients summed into a single scalar, and the values appear to match those I get by computing them individually with the code in my original post.

As I said, I am not experienced in this matter and I am probably misunderstanding something, but could you point out any misconception I am making?
Thanks in advance

Hi Malfonsoarquimea!

Here is what is probably going on:

If your model:

outputs=model(preprocessing(inputs))

maps a batch of inputs, say of shape [nBatch, N], to a batch of scalar
outputs of shape [nBatch] (or similar), and each batch element of
outputs depends only on the corresponding element of inputs (that is,
outputs[i] depends only on inputs[i], and not on inputs[j != i]),
then the gradient of the scalar value outputs.sum() with respect to
inputs[i] will, in fact, be the gradient of outputs[i] with respect to
inputs[i].

All of the zeros you get when you compute the gradient of outputs[i]
with respect to all of the elements of inputs are to be expected.
inputs.grad[i] is non-zero, as outputs[i] depends on inputs[i],
while all of the other inputs.grad[j != i] are zero because (by my
assumption) outputs[i] does not depend on inputs[j != i] – no
dependence, so zero gradient.

The last step, reiterating what I said in my previous post, is that:

torch.autograd.grad (outputs = outputs, inputs = inputs, grad_outputs = torch.ones_like (outputs))

is equal to:

torch.autograd.grad (outputs = outputs.sum(), inputs = inputs, grad_outputs = None)
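A quick numerical check of this equivalence, with a made-up function whose derivative is known analytically:

```python
import torch

torch.manual_seed(0)

inputs = torch.randn(5, 3, requires_grad=True)
outputs = (inputs ** 3).sum(dim=1, keepdim=True)  # hypothetical outputs, shape [5, 1]

(g_ones,) = torch.autograd.grad(
    outputs=outputs, inputs=inputs, grad_outputs=torch.ones_like(outputs), retain_graph=True
)
(g_sum,) = torch.autograd.grad(outputs=outputs.sum(), inputs=inputs, grad_outputs=None)

print(torch.allclose(g_ones, g_sum))  # True
print(torch.allclose(g_sum, 3 * inputs ** 2))  # True: matches d(x^3)/dx
```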

In short, if your use case only requires that you compute the gradients
of a batch of scalar outputs with respect to a batch of inputs where each
output batch element only depends on the corresponding input batch
element, then you do only need one call to torch.autograd.grad() (or,
if you prefer, outputs.sum().backward()), and you only backpropagate
through the computation graph once.
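One caveat worth checking in a real model: this shortcut relies entirely on the independence assumption. A layer that uses batch statistics, such as torch.nn.BatchNorm1d in training mode, makes every output depend on every input in the batch, so the summed gradient is no longer the per-output gradient. The sketch below uses a hypothetical model (not the poster's) and compares the two in training and eval mode, using outputs.sum().backward() for the single pass.

```python
import torch

torch.manual_seed(0)

# Hypothetical model containing a BatchNorm layer.
model = torch.nn.Sequential(
    torch.nn.Linear(3, 8), torch.nn.BatchNorm1d(8), torch.nn.Tanh(), torch.nn.Linear(8, 1)
)

def matches_per_output(model, x, i=0):
    """Compare row i of grad(outputs.sum()) with the gradient of outputs[i] alone."""
    x = x.clone().requires_grad_()
    outputs = model(x)
    (g_i,) = torch.autograd.grad(outputs[i], x, retain_graph=True)
    outputs.sum().backward()  # single backward pass, fills x.grad
    return torch.allclose(x.grad[i], g_i[i])

x = torch.randn(16, 3)

model.train()  # BatchNorm uses batch statistics, so rows are coupled
print(matches_per_output(model, x))  # False

model.eval()   # BatchNorm uses running statistics, so rows are independent
print(matches_per_output(model, x))  # True
```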

Best.

K. Frank

Thanks very much for your detailed answer, K. Frank, it is really helpful!
In fact, I am using a NeRF-like network where the inputs are XYZ coordinates and the output is the density of space at those coordinates. I would like to compute the gradient of the density (output) at each point in my batch of coordinates (input). So, as far as I understand, your explanation applies, because the density at one point depends only on the coordinates of that point.
As I said, thanks very much for your time. Do you have any posts you can recommend for learning how PyTorch autograd works in detail?
Kind Regards!
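For the NeRF-like case, the one-call approach might look roughly like this. DensityMLP and density_gradient are hypothetical names for illustration (no positional encoding, no view direction), and create_graph=True is only needed if the gradient itself enters a loss that is backpropagated.

```python
import torch

class DensityMLP(torch.nn.Module):
    """Hypothetical stand-in: maps XYZ points [N, 3] to densities [N, 1]."""

    def __init__(self, hidden=64):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(3, hidden), torch.nn.Softplus(),
            torch.nn.Linear(hidden, hidden), torch.nn.Softplus(),
            torch.nn.Linear(hidden, 1),
        )

    def forward(self, xyz):
        return self.net(xyz)

def density_gradient(model, xyz, create_graph=False):
    """Hypothetical helper: density and d(density)/d(xyz) for every point in one pass."""
    xyz = xyz.detach().requires_grad_()
    sigma = model(xyz)
    (grad,) = torch.autograd.grad(
        outputs=sigma,
        inputs=xyz,
        grad_outputs=torch.ones_like(sigma),
        create_graph=create_graph,  # True only if the gradient feeds a loss
    )
    return sigma, grad  # grad has shape [N, 3]

torch.manual_seed(0)
model = DensityMLP()
points = torch.rand(1024, 3)
sigma, grad = density_gradient(model, points)
print(sigma.shape, grad.shape)  # torch.Size([1024, 1]) torch.Size([1024, 3])
```

This works because each point's density depends only on its own coordinates; a layer that mixes points within a batch would break that, as discussed above.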

I am revisiting this and still find your explanation super illustrative. Thanks again!