Hi Malfonsoarquimea!
No, by using grad_outputs = torch.ones_like (output) you end up computing the
inner product (dot product) of the gradients of output with the
torch.ones_like() vector. This does not preserve the gradient of each output
as a separate entity.
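(To make the "inner product" point concrete, here is a minimal sketch, using
the same tiny one-weight Linear and inputs as the example further down; the
weighting vector v is just an arbitrary illustrative choice. Passing
gradient = v to backward() hands you the single number sum_i v[i] * grad_i,
not the individual grad_i.)

import torch

lin = torch.nn.Linear (1, 1, bias = False)          # one-weight "model"
t = (torch.arange (4.) + 1).unsqueeze (-1)          # inputs 1., 2., 3., 4.
v = torch.tensor ([[1.], [10.], [100.], [1000.]])   # arbitrary weighting vector

lin.weight.grad = None
lin (t).backward (gradient = v)                     # vector-Jacobian product

# the per-output gradients are 1., 2., 3., 4., so this prints
# 1*1 + 10*2 + 100*3 + 1000*4 = 4321, a single number in which the
# individual gradients are mixed together
print (lin.weight.grad)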
At issue is the fact that, as currently implemented, autograd computes the
gradient of a single scalar result (in any individual backward pass). To get
the gradients of the elements of output separately, you have to run multiple
backward passes in a loop, as you did in your original post. And running such
a loop over the elements of output does take time for each iteration.
This example illustrates these points:
>>> torch.__version__
'1.10.2'
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> lin = torch.nn.Linear (1, 1, bias = False) # a trivial "model"
>>> lin.weight
Parameter containing:
tensor([[-0.2084]], requires_grad=True)
>>> t = (torch.arange (4.) + 1).unsqueeze (-1)
>>> t
tensor([[1.],
[2.],
[3.],
[4.]])
>>>
>>> lin (t) # output consists of four separate scalars
tensor([[-0.2084],
[-0.4168],
[-0.6252],
[-0.8336]], grad_fn=<MmBackward0>)
>>>
>>> # use loop to compute gradient with respect to each scalar output
... # (for simplicity, this example unnecessarily runs four forward passes)
... for i in range (4):
... lin.weight.grad = None
... lin (t)[i].backward()
... print ('i:', i, ' gradient:', lin.weight.grad)
...
i: 0 gradient: tensor([[1.]])
i: 1 gradient: tensor([[2.]])
i: 2 gradient: tensor([[3.]])
i: 3 gradient: tensor([[4.]])
>>> # gradient of a single scalar result
... lin.weight.grad = None
>>> lin (t).sum().backward() # sum over outputs to get a single scalar result
>>> lin.weight.grad
tensor([[10.]])
>>>
>>> # using gradient = torch.ones() still just gives a single gradient
... lin.weight.grad = None
>>> lin (t).backward (gradient = torch.ones (4, 1))
>>> lin.weight.grad
tensor([[10.]])
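(A side note on the parenthetical comment in the loop above: you can avoid
repeating the forward pass by computing output once and reusing the graph.
Here is a sketch, continuing with lin and t from the session above and using
torch.autograd.grad with retain_graph = True; it still runs one backward pass
per output element, so the per-iteration cost remains.)

out = lin (t)                                        # single forward pass
for i in range (4):
    # one backward pass per output element; keep the graph for the next pass
    grad_i = torch.autograd.grad (out[i], lin.weight, retain_graph = True)[0]
    print ('i:', i, ' gradient:', grad_i)            # 1., 2., 3., 4., as above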
(This example is framed in terms of computing gradients with respect to the
weights of a “layer,” rather than with respect to its “inputs,” as this is the
more common use case, but there is no conceptual difference.)
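(A quick illustrative sketch of the “inputs” version, reusing lin from above
and just marking the input tensor as requiring grad:)

t = (torch.arange (4.) + 1).unsqueeze (-1)
t.requires_grad = True   # track gradients with respect to the input

lin (t)[2].backward()    # gradient of a single scalar output with respect to t
print (t.grad)           # zero except in position 2, which holds the value of lin.weight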
Best.
K. Frank