Hello.

Given a function $f: \mathbb{R}^{3} \to \mathbb{R}^{3}$ implemented as a *point-wise* MLP, what is the most efficient way of computing its Jacobian matrix?

In my case, the input and output both have shape `(B, N, 3)`, and the desired Jacobian matrix has shape `(B, N, 3, 3)`, since the feedforward function is applied point-wise. I first `view` the input as `(B*N, 3)` and call `torch.autograd.functional.jacobian`, which returns a tensor of shape `(B*N, 3, B*N, 3)`. If I understand correctly, gradients are computed between every point in my input and every point in the output. Since the cross-point blocks are all zero for a point-wise function, `jacobian.sum(dim=2)` gives me the desired `(B*N, 3, 3)` output. Instead of computing a backward pass between different points (which yields only zeros), could the Jacobian computation be done more efficiently?

I have also come up with an implementation that is faster than `torch.autograd.functional.jacobian`, as shown below. However, I'm not sure whether the slicing and `cat` operations applied here break the gradient flow. In practice, I observe that my loss doesn't change at all, so I'm curious what could be going wrong.

```python
def compute_jacobian(inp, out):
    # inp: (B*N, 3) with requires_grad=True, out: (B*N, 3)
    u = out[:, 0]
    v = out[:, 1]
    w = out[:, 2]
    grad_outputs = torch.ones_like(u)
    # Summing over points via grad_outputs=ones is exact here because the
    # map is point-wise, so all cross-point derivatives are zero.
    grad_u = torch.autograd.grad(u, [inp], grad_outputs=grad_outputs, create_graph=True)[0]
    grad_v = torch.autograd.grad(v, [inp], grad_outputs=grad_outputs, create_graph=True)[0]
    grad_w = torch.autograd.grad(w, [inp], grad_outputs=grad_outputs, create_graph=True)[0]
    # Each 3x3 block has rows [du/dx du/dy du/dz], [dv/...], [dw/...].
    return torch.cat((grad_u, grad_v, grad_w), dim=-1).reshape(-1, 3, 3)
```

```python
def compute_jacobian_batch(inp, out):
    # inp: (B, N, 3) with requires_grad=True, out: (B, N, 3)
    u = out[:, :, 0]
    v = out[:, :, 1]
    w = out[:, :, 2]
    n_batch = inp.size(0)
    grad_outputs = torch.ones_like(u)
    grad_u = torch.autograd.grad(u, [inp], grad_outputs=grad_outputs, create_graph=True)[0]
    grad_v = torch.autograd.grad(v, [inp], grad_outputs=grad_outputs, create_graph=True)[0]
    grad_w = torch.autograd.grad(w, [inp], grad_outputs=grad_outputs, create_graph=True)[0]
    jacob_mat = torch.cat((grad_u, grad_v, grad_w), dim=-1).reshape(n_batch, -1, 3, 3).contiguous()
    jacob_mat.retain_grad()
    return jacob_mat
```
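For reference, here is the sanity check I am running on the first function, comparing it against `torch.autograd.functional.jacobian` on a small stand-in net (the network and sizes are placeholders for my actual model; `compute_jacobian` is repeated so the snippet runs on its own):

```python
import torch

torch.manual_seed(0)

# Stand-in point-wise net; the architecture is a placeholder for my MLP.
net = torch.nn.Sequential(
    torch.nn.Linear(3, 8),
    torch.nn.Tanh(),
    torch.nn.Linear(8, 3),
)

def compute_jacobian(inp, out):
    u, v, w = out[:, 0], out[:, 1], out[:, 2]
    grad_outputs = torch.ones_like(u)
    grad_u = torch.autograd.grad(u, [inp], grad_outputs=grad_outputs, create_graph=True)[0]
    grad_v = torch.autograd.grad(v, [inp], grad_outputs=grad_outputs, create_graph=True)[0]
    grad_w = torch.autograd.grad(w, [inp], grad_outputs=grad_outputs, create_graph=True)[0]
    return torch.cat((grad_u, grad_v, grad_w), dim=-1).reshape(-1, 3, 3)

inp = torch.randn(10, 3, requires_grad=True)  # inp must require grad
out = net(inp)
jac_fast = compute_jacobian(inp, out)

# Dense reference Jacobian, with the zero cross-point axis summed out.
jac_ref = torch.autograd.functional.jacobian(net, inp).sum(dim=2)

print(torch.allclose(jac_fast, jac_ref, atol=1e-6))  # expect True for a point-wise net
```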

Can someone please verify this implementation? Thank you!