Problem with shapes in autograd.grad

Hi. I'm trying to compute the gradient of a non-scalar tensor with torch.autograd.grad. The input xyz has shape (B, Ns, 3) and the output sigma has shape (B, Ns, Nf). I use the following code to compute it:

gradients = torch.autograd.grad(
    outputs=sigma,
    inputs=xyz,
    grad_outputs=torch.ones_like(sigma),
)[0]
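For context, here is a small sketch of why the result has the input's shape. torch.autograd.grad computes a vector-Jacobian product: grad_outputs is contracted against every output element, so with all-ones grad_outputs the result equals the gradient of sigma.sum() with respect to xyz. The sizes and the toy Linear + tanh network are hypothetical stand-ins for the real model.

```python
import torch

# Hypothetical sizes and toy network standing in for the real model.
B, Ns, Nf = 2, 5, 4
torch.manual_seed(0)
xyz = torch.randn(B, Ns, 3, requires_grad=True)
net = torch.nn.Linear(3, Nf)
sigma = torch.tanh(net(xyz))  # shape (B, Ns, Nf)

# Vector-Jacobian product: grad_outputs is contracted with the Jacobian,
# so the result always has the shape of the input, (B, Ns, 3).
vjp = torch.autograd.grad(
    outputs=sigma,
    inputs=xyz,
    grad_outputs=torch.ones_like(sigma),
    retain_graph=True,  # keep the graph for the second call below
)[0]

# With all-ones grad_outputs this is the same as differentiating sigma.sum().
summed = torch.autograd.grad(sigma.sum(), xyz)[0]

print(vjp.shape)                    # torch.Size([2, 5, 3])
print(torch.allclose(vjp, summed))  # True
```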

The problem is that gradients has shape (B, Ns, 3), whereas I was expecting (B, Ns, Nf, 3). So far I have tried computing the gradients for each of the Nf output channels of sigma separately, and that seems to work, but is there a way to compute them directly, and more efficiently, with a single call to autograd.grad?
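A minimal sketch of the per-channel loop described above next to a batched alternative. It assumes each sigma[b, n, :] depends only on xyz[b, n, :] (as in a point-wise MLP) and PyTorch 1.11 or newer for the is_grads_batched flag of torch.autograd.grad. The toy network and sizes are hypothetical. The batched call stacks one one-hot grad_outputs per channel along a new leading dimension, so a single call returns all Nf vector-Jacobian products.

```python
import torch

# Hypothetical sizes and point-wise toy network (assumption: sigma[b, n]
# depends only on xyz[b, n]).
B, Ns, Nf = 2, 5, 4
torch.manual_seed(0)
xyz = torch.randn(B, Ns, 3, requires_grad=True)
net = torch.nn.Linear(3, Nf)
sigma = torch.tanh(net(xyz))  # (B, Ns, Nf)

# 1) Loop: one backward pass per output channel, using a one-hot grad_outputs.
rows = []
for k in range(Nf):
    go = torch.zeros_like(sigma)
    go[..., k] = 1.0
    rows.append(torch.autograd.grad(sigma, xyz, grad_outputs=go, retain_graph=True)[0])
jac_loop = torch.stack(rows, dim=2)  # (B, Ns, Nf, 3)

# 2) Batched: all Nf one-hot vectors stacked on a leading batch dimension.
basis = torch.eye(Nf).reshape(Nf, 1, 1, Nf).repeat(1, B, Ns, 1)  # (Nf, B, Ns, Nf)
jac_batched = torch.autograd.grad(
    sigma, xyz, grad_outputs=basis, is_grads_batched=True, retain_graph=True
)[0]  # (Nf, B, Ns, 3)
jac_batched = jac_batched.permute(1, 2, 0, 3)  # (B, Ns, Nf, 3)

# Analytic check for this toy model: d tanh(Wx + b)_k / dx_j = (1 - s_k^2) * W[k, j].
analytic = (1 - sigma.detach() ** 2).unsqueeze(-1) * net.weight.detach()

print(jac_batched.shape)  # torch.Size([2, 5, 4, 3])
```

Without the point-wise assumption, summing over B and Ns mixes contributions from different points, and the full Jacobian would have shape (B, Ns, Nf, B, Ns, 3). is_grads_batched relies on vmap internally, so operations without a batching rule may fall back to a slower path.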

Thanks
