# Derivative of vector valued functions

Hello.

Given a function $f : \mathbb{R}^{3} \to \mathbb{R}^{3}$ implemented as a point-wise MLP, what is the most efficient way of computing its Jacobian matrix?

In my case, my input and output are of shape `(B, N, 3)` and the desired Jacobian would be of shape `(B, N, 3, 3)`, since the feedforward function is applied point-wise. I first `view` the input as `(B*N, 3)` and call `torch.autograd.functional.jacobian`, which returns an output of shape `(B*N, 3, B*N, 3)`. If I understand correctly, gradients are computed between every point in the input and every point in the output. Accordingly, `jacobian.sum(dim=2)` gives me the desired `(B*N, 3, 3)` output. Instead of computing a backward pass between unrelated points (whose entries are all 0), could the Jacobian computation be done more efficiently?
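For concreteness, here is a small sketch of that approach, with a toy element-wise map standing in for the MLP (the function `f` below is illustrative, not the actual network):

```python
import torch

# Toy point-wise stand-in for the MLP: each point's output depends only
# on that point, so all cross-point Jacobian blocks are zero.
def f(x):                    # x: (M, 3), M playing the role of B*N
    return x ** 2 + x.sin()

x = torch.randn(4, 3)
jac = torch.autograd.functional.jacobian(f, x)  # shape (4, 3, 4, 3)
per_point = jac.sum(dim=2)                      # (4, 3, 3); only valid because
                                                # the off-block entries are zero
```

For this element-wise `f`, each per-point Jacobian is `diag(2*x + cos(x))`, which is an easy sanity check for the sum trick.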

I have also come up with an implementation that is faster than `torch.autograd.functional.jacobian`, shown below. However, I'm not sure whether the slicing and `cat` operations used here break the gradient. In practice, I observe that my loss doesn't change at all, so I'm curious what could be going wrong.

```python
import torch

def compute_jacobian(inp, out):
    # inp: (N, 3) with requires_grad=True; out: (N, 3)
    u, v, w = out[:, 0], out[:, 1], out[:, 2]
    # one backward pass per output component; create_graph=True keeps the
    # Jacobian differentiable so it can contribute to a loss
    du = torch.autograd.grad(u, inp, torch.ones_like(u), create_graph=True)[0]
    dv = torch.autograd.grad(v, inp, torch.ones_like(v), create_graph=True)[0]
    dw = torch.autograd.grad(w, inp, torch.ones_like(w), create_graph=True)[0]
    return torch.stack([du, dv, dw], dim=1)  # (N, 3, 3)
```
```python
def compute_jacobian_batch(inp, out):
    # inp, out: (B, N, 3); the network is point-wise, so grads stay per point
    u, v, w = out[:, :, 0], out[:, :, 1], out[:, :, 2]
    du = torch.autograd.grad(u, inp, torch.ones_like(u), create_graph=True)[0]
    dv = torch.autograd.grad(v, inp, torch.ones_like(v), create_graph=True)[0]
    dw = torch.autograd.grad(w, inp, torch.ones_like(w), create_graph=True)[0]
    jacob_mat = torch.stack([du, dv, dw], dim=2)  # (B, N, 3, 3)
    return jacob_mat
```
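As a side note on the "does slicing break the gradient" worry: indexing and `torch.cat` are differentiable operations and do not cut the autograd graph by themselves. A minimal check with toy tensors (illustrative only):

```python
import torch

x = torch.randn(4, 3, requires_grad=True)
y = x ** 2
u = y[:, 0]                               # slicing keeps the autograd graph
cat = torch.cat([u, y[:, 1]], dim=0)      # so does cat
g = torch.autograd.grad(cat.sum(), x)[0]  # gradients flow back to x
# g[:, 0] == 2 * x[:, 0], g[:, 1] == 2 * x[:, 1], g[:, 2] == 0
```

If the loss does not move, a more likely culprit is calling `torch.autograd.grad` without `create_graph=True`: by default the returned gradients are detached from the graph, so a loss built on them produces no parameter updates.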

Can someone please verify this implementation? Thank you!

Hi,

Is doing the sum of the big Jacobian what you want? Or do you actually want the diagonal elements there (batch i only takes into account gradients for that one batch and not the others)?
Or is your network structure guaranteeing that batches do not "interact"?

Hello,

No, the opposite: if my input is `(N, 3)`, the output will be `(N, 3)`. Since my network applies point-wise convolutions, there is no interaction between the i-th and (i+1)-th point, so I want a "point-wise" derivative. To that end, the Jacobian that PyTorch computes (of size `(N, 3, N, 3)`) is overkill, as most of its elements will be 0 due to this lack of interaction. The diagonal alone is an underestimate, as it misses the off-diagonal partial derivatives within each point's 3×3 block.

> Or is your network structure guaranteeing that batches do not "interact"?

Indeed!

I meant diagonal in a generalized (block) sense here. But yes, the summation gives you what you want, as all the off-diagonal elements will be 0.

Under that assumption, you can indeed do things in a faster way. I think the simplest is to expand your parameters to have a batch dimension, and do a regular backward with `ones()` like you do. Your weights' gradients will then be what you're looking for.
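On recent PyTorch (2.0+), `torch.func` can also compute per-point Jacobians directly, without ever materializing the big block matrix. A sketch with a hypothetical per-point map `g` standing in for the point-wise MLP:

```python
import torch
from torch.func import jacrev, vmap

# Hypothetical per-point map g: R^3 -> R^3 (stand-in for the point-wise MLP)
def g(p):                      # p: (3,)
    return torch.stack([p[0] * p[1], p[1] ** 2, p[2].sin()])

pts = torch.randn(5, 3)        # N = 5 points
jac = vmap(jacrev(g))(pts)     # (5, 3, 3): one 3x3 Jacobian per point
```

For a `(B, N, 3)` batch, `view` to `(B*N, 3)`, apply the same `vmap(jacrev(g))`, and `view` the result back to `(B, N, 3, 3)`.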
