Problem with shapes in autograd.grad

Hi. I’m trying to compute the gradient of a non-scalar tensor with torch.autograd.grad. The input xyz has shape (B, Ns, 3) and the output sigma has shape (B, Ns, Nf). I use the following code to compute it:

gradients = torch.autograd.grad(
        outputs=sigma,
        inputs=xyz,
        grad_outputs=torch.ones_like(sigma),
        create_graph=True,
)[0]
The problem is that the resulting gradients have shape (B, Ns, 3), while I was expecting (B, Ns, Nf, 3). So far I have tried computing the gradients for each of the Nf elements of the output ‘sigma’ separately, and that seems to work, but is there a way to compute the full gradient directly and more efficiently with a single autograd.grad call?
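For reference, here is a minimal, self-contained sketch of the per-element workaround I mentioned. The shapes and the linear layer standing in for my real network are just placeholders:

```python
import torch

B, Ns, Nf = 2, 4, 5
xyz = torch.randn(B, Ns, 3, requires_grad=True)

# Toy stand-in for the real model: maps (B, Ns, 3) -> (B, Ns, Nf).
w = torch.randn(3, Nf)
sigma = xyz @ w  # shape (B, Ns, Nf)

# Differentiate each output channel f separately, then stack the
# per-channel gradients into a (B, Ns, Nf, 3) tensor.
grads = []
for f in range(Nf):
    (g,) = torch.autograd.grad(
        outputs=sigma[..., f].sum(),  # scalar, so grad has xyz's shape
        inputs=xyz,
        retain_graph=True,  # keep the graph alive for the next channel
    )
    grads.append(g)  # each g has shape (B, Ns, 3)

gradients = torch.stack(grads, dim=2)  # shape (B, Ns, Nf, 3)
```

This calls autograd.grad Nf times, which is the inefficiency I would like to avoid.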

