Hi. I’m trying to compute the gradient of a non-scalar tensor with torch.autograd.grad. The input xyz has shape (B, Ns, 3) and the output sigma has shape (B, Ns, Nf). I use the following code to compute it:
gradients = torch.autograd.grad(
    outputs=sigma,
    inputs=xyz,
    grad_outputs=torch.ones_like(sigma),
)[0]
The problem is that the shape of gradients is (B, Ns, 3), when I was expecting (B, Ns, Nf, 3). So far I have tried computing the gradients separately for each of the Nf features of the output ‘sigma’, and that seems to work, but is there a way to compute the full result directly and more efficiently using just autograd.grad?
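For reference, the per-feature workaround mentioned above looks roughly like the sketch below. The toy `weight` matrix and the `tanh` model are placeholders I made up so the snippet runs on its own; they are not my real network. It also assumes each output row sigma[b, n, :] depends only on the matching input row xyz[b, n, :], which is what makes summing over B and Ns harmless.

```python
import torch

# Placeholder sizes and model (assumptions for illustration only).
B, Ns, Nf = 2, 4, 5
torch.manual_seed(0)
weight = torch.randn(3, Nf)

xyz = torch.randn(B, Ns, 3, requires_grad=True)
sigma = torch.tanh(xyz @ weight)  # shape (B, Ns, Nf)

per_feature = []
for f in range(Nf):
    # One backward pass per output feature; retain_graph keeps the
    # graph alive so the next iteration can backpropagate again.
    grad_f = torch.autograd.grad(
        outputs=sigma[..., f],
        inputs=xyz,
        grad_outputs=torch.ones_like(sigma[..., f]),
        retain_graph=True,
    )[0]  # shape (B, Ns, 3)
    per_feature.append(grad_f)

# Stack along a new feature axis to get shape (B, Ns, Nf, 3).
gradients = torch.stack(per_feature, dim=2)
```

This gives the shape I want, but it needs Nf separate backward passes, which is the part I would like to avoid.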
Thanks