When I call backward() on a non-scalar variable $y$, the resulting gradient always has the same shape as the input $x$.
Is there a way to get a y-shaped result instead?
For example:
y = model(x) # x.shape: (B, 1), y.shape: (B, K)
y.backward(torch.ones_like(y))
x.grad.shape == x.shape # (B, 1)
>>> True
But what I want to get is
$$
\frac{\partial y}{\partial x} = \left(\frac{\partial y_1}{\partial x}, \frac{\partial y_2}{\partial x}, \dots, \frac{\partial y_K}{\partial x}\right)^T
$$
a result of shape (B, K).
My current solution is a for loop that computes each $\frac{\partial y_i}{\partial x}$ separately, but it is too slow. Is there a better way?
x_grads = []
for i in range(y.shape[-1]):  # K
    # Gradient of the i-th output column w.r.t. x, shape (B, 1)
    x_grad = torch.autograd.grad(y[..., i],
                                 x,
                                 torch.ones_like(y[..., i]),
                                 retain_graph=True)[0]
    x_grads.append(x_grad)
x_grads = torch.cat(x_grads, dim=-1)  # (B, K)
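One possible direction, offered as a sketch and not a definitive answer: `torch.autograd.grad` accepts `is_grads_batched=True`, which treats the first dimension of `grad_outputs` as a batch of vector-Jacobian products and evaluates them in a single call. The helper name `per_output_grads` and the toy `model` below are hypothetical, introduced only for illustration. The sketch assumes `x` has shape (B, 1), `y` has shape (B, K), and each row of `y` depends only on the matching row of `x`. It also assumes every op in the model supports the batched backward path, which may not hold for all models.

```python
import torch


def per_output_grads(y, x):
    """Return d y[:, i] / d x for every i as a (B, K) tensor.

    Hypothetical helper. Assumes x is (B, 1), y is (B, K), and that
    row b of y depends only on row b of x.
    """
    K = y.shape[-1]
    # basis[i] is a y-shaped tensor with ones in column i, zeros elsewhere.
    basis = torch.eye(K, dtype=y.dtype, device=y.device)        # (K, K)
    basis = basis.unsqueeze(1).expand(K, *y.shape).contiguous()  # (K, B, K)
    # One call computes all K vector-Jacobian products at once.
    (grads,) = torch.autograd.grad(
        y, x, grad_outputs=basis, is_grads_batched=True, retain_graph=True
    )
    # grads has shape (K, B, 1); rearrange to (B, K).
    return grads.squeeze(-1).transpose(0, 1)


# Toy stand-in for the real model (an assumption for this example).
torch.manual_seed(0)
B, K = 4, 3
model = torch.nn.Sequential(
    torch.nn.Linear(1, 8), torch.nn.Tanh(), torch.nn.Linear(8, K)
)
x = torch.randn(B, 1, requires_grad=True)
y = model(x)

batched = per_output_grads(y, x)

# Reference result from the explicit loop, for comparison.
loop = torch.cat(
    [
        torch.autograd.grad(
            y[..., i], x, torch.ones_like(y[..., i]), retain_graph=True
        )[0]
        for i in range(K)
    ],
    dim=-1,
)
```

Comparing `batched` against `loop` with `torch.allclose` is a quick way to check that the two agree on a given model before dropping the loop.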