# Derivative of vector valued functions

Hello.

Given a function $f: \mathbb{R}^{3} \to \mathbb{R}^{3}$ implemented as a point-wise MLP, what is the most efficient way of computing the Jacobian matrix?

In my case, the input and output both have shape `(B, N, 3)`, and the desired Jacobian has shape `(B, N, 3, 3)` because the feedforward function is applied point-wise. I first `view` the input as `(B*N, 3)` and call `torch.autograd.functional.jacobian`, which returns a tensor of shape `(B*N, 3, B*N, 3)`. If I understand correctly, gradients are computed between every point in the input and every point in the output. As expected, `jacobian.sum(dim=2)` gives me the desired `(B*N, 3, 3)` output. Since the gradients between different points are all zero, could the Jacobian be computed more efficiently, without those backward passes?
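To make the shapes concrete, here is a minimal sketch of the baseline described above. The small `mlp`, the sizes `B` and `N`, and the variable names are assumptions for illustration, not the poster's actual network. It shows that the full Jacobian has shape `(B*N, 3, B*N, 3)` and that summing over `dim=2` recovers the same values as picking out the per-point (block-diagonal) entries.

```python
import torch
import torch.nn as nn
from torch.autograd.functional import jacobian

torch.manual_seed(0)
B, N = 2, 4

# Hypothetical point-wise MLP: nn.Linear acts on the last dimension only,
# so different points never interact.
mlp = nn.Sequential(nn.Linear(3, 16), nn.Tanh(), nn.Linear(16, 3))

x = torch.randn(B, N, 3)
flat = x.view(B * N, 3)

# Full Jacobian: output shape followed by input shape.
full = jacobian(mlp, flat)            # (B*N, 3, B*N, 3)

# Cross-point blocks are zero, so summing over the input-point axis
# leaves only the per-point 3x3 blocks.
per_point = full.sum(dim=2)           # (B*N, 3, 3)

# The same blocks, read directly off the block diagonal.
idx = torch.arange(B * N)
diag_blocks = full[idx, :, idx, :]    # (B*N, 3, 3)

print(full.shape, per_point.shape)
print(torch.allclose(per_point, diag_blocks))
```

The full tensor holds `(B*N)**2 * 9` entries, of which only `B*N * 9` can be non-zero, which is why this route scales poorly with the number of points.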

I have also come up with an implementation that is faster than `torch.autograd.functional.jacobian`, shown below. However, I'm not sure whether the slicing and `cat` operations used here break the gradient flow. In practice, my loss doesn't change at all, so I'm curious what could be going wrong.

```python
def compute_jacobian(inp, out):
    u = out[:, 0]
    v = out[:, 1]
    w = out[:, 2]
    # ...
```

```python
def compute_jacobian_batch(inp, out):
    u = out[:, :, 0]
    v = out[:, :, 1]
    w = out[:, :, 2]
    n_batch = inp.size(0)
    # ...
    return jacob_mat
```

Can someone please verify this implementation? Thank you!
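The bodies of the two functions are cut short in the post, so the following is only a sketch of what the surrounding text seems to describe: slice the output into components, run one backward pass per component with `ones` as `grad_outputs`, and combine the results. The function name `pointwise_jacobian`, the example `mlp`, and the test loss are assumptions, not the poster's code. One possible cause of a loss that never changes, offered as a guess, is calling `torch.autograd.grad` without `create_graph=True`, in which case the returned gradients are detached from the network parameters.

```python
import torch
import torch.nn as nn

def pointwise_jacobian(inp, out, create_graph=True):
    """Per-point Jacobian with entries d out[..., i] / d inp[..., j].

    Assumes `out` was computed from `inp` point-wise (no interaction
    between points) and that `inp.requires_grad` is True.
    Returns a tensor of shape inp.shape[:-1] + (out_dim, in_dim).
    """
    rows = []
    for i in range(out.shape[-1]):
        comp = out[..., i]  # slicing keeps the autograd graph intact
        (grad_i,) = torch.autograd.grad(
            comp,
            inp,
            grad_outputs=torch.ones_like(comp),
            create_graph=create_graph,  # needed to backprop through the Jacobian
            retain_graph=True,          # the graph is reused for the next component
        )
        rows.append(grad_i)
    # Row i holds the gradient of output component i.
    return torch.stack(rows, dim=-2)

torch.manual_seed(0)
mlp = nn.Sequential(nn.Linear(3, 16), nn.Tanh(), nn.Linear(16, 3))  # hypothetical
inp = torch.randn(2, 4, 3, requires_grad=True)
out = mlp(inp)

jac = pointwise_jacobian(inp, out)   # (2, 4, 3, 3)
loss = jac.pow(2).mean()             # any loss built from the Jacobian
loss.backward()
print(jac.shape, mlp[0].weight.grad is not None)
```

This needs only three backward passes regardless of `B` and `N`, because each pass with `ones` sums over points and the cross-point terms vanish.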

Hi,

Is summing the big Jacobian really what you want? Or do you actually want its diagonal elements (so that batch i only takes into account gradients for that one batch and not the others)?
Or does your network structure guarantee that batches do not “interact”?

Hello,

No, the opposite. If my input is `(N, 3)`, the output will also be `(N, 3)`. Since my network applies point-wise convolution, there is no interaction between the ith and (i+1)st points, so I want a “point-wise” derivative. To that end, the Jacobian that PyTorch computes (of size `(N, 3, N, 3)`) is overkill, as most of its elements will be 0 due to the lack of interaction. The plain diagonal is not enough either, as it leaves out the cross partial derivatives within each point.

> Or is your network structure guaranteeing that batches do not “interact”?

Indeed!

I meant diagonal in a generalized (block-diagonal) sense here. But yes, the summation gives you what you want, as all the off-diagonal elements will be 0.

Under that assumption, you can indeed do things faster. I think the simplest approach is to expand your parameters to have a batch dimension, then do a regular backward with `ones()` as you already do. Your weights' gradients will then be what you're looking for.
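As a different route from the one suggested in the reply above, the per-point Jacobian with respect to the inputs can also be sketched with `torch.func`, assuming PyTorch 2.0 or later. Here `jacrev` differentiates the function for a single point of shape `(3,)` and `vmap` maps it over all points, so the zero cross-point blocks are never formed. The `mlp` and the helper `f` are hypothetical stand-ins for the actual network.

```python
import torch
import torch.nn as nn
from torch.func import jacrev, vmap

torch.manual_seed(0)
mlp = nn.Sequential(nn.Linear(3, 16), nn.Tanh(), nn.Linear(16, 3))  # hypothetical

def f(point):
    # Single point of shape (3,) -> output of shape (3,)
    return mlp(point)

x = torch.randn(2, 4, 3)
flat = x.reshape(-1, 3)                 # (B*N, 3)

# jacrev(f) maps one point to its 3x3 Jacobian; vmap batches that over points.
jac = vmap(jacrev(f))(flat)             # (B*N, 3, 3)
jac = jac.reshape(2, 4, 3, 3)           # (B, N, 3, 3)
print(jac.shape)
```

This relies on the same assumption as the rest of the thread: the network treats every point independently, so layers that mix points (for example batch normalization in training mode) would invalidate it.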
