Jacobian computation with a single call to grad

Hi there, I have been experimenting with computing a Jacobian. My goal was to find the fastest possible implementation, so my idea was to pass all output/input combinations to the grad function in a single call instead of iterating over them with a for loop. That method is called jacobian in the code below. For comparison, I added a function called jacobian2, which uses a for loop to compute one row of the Jacobian per iteration.

I am aware of the jacobian function that has been available since v1.5, but I originally implemented my version with libtorch in C++ (where that function is not available yet). When I noticed that my function behaved strangely, I reimplemented it in PyTorch to check whether the results are the same. This is my code:

import torch
from torch.autograd import grad

def jacobian(input, output):
    input_dim = input.size(-1)
    output_dim = output.size(-1)
    batch_size = output.size(0)
    grad_output = torch.ones([1])
    
    outs = []
    inputs = []
    grad_outputs = []
    for i in range(output_dim):
        outs.append(output[i:i+1])
        inputs.append(input)
        grad_outputs.append(grad_output)

    gradient = grad(outs, inputs, grad_outputs=grad_outputs,
                    retain_graph=True, create_graph=True, allow_unused=False)

    J = torch.stack(gradient, dim=-2)
    return J

def jacobian2(input, output):
    input_dim = input.size(-1)
    output_dim = output.size(-1)
    batch_size = output.size(0)
    grad_output = torch.ones([1])
    
    outs = []
    inputs = []
    grad_outputs = []
    for i in range(output_dim):
        outs.append(output[i:i+1])
        inputs.append(input)
        grad_outputs.append(grad_output)
        
    J = torch.zeros(output_dim, input_dim)
    for i in range(output_dim):
        gradient = grad([outs[i]], [inputs[i]], grad_outputs=[grad_outputs[i]],
                    retain_graph=True, create_graph=True, allow_unused=False)
        J[i, :] = gradient[0]
    return J

q = torch.tensor([1.0, 2.0]).requires_grad_(True)
print(f"q: {q}")
x = torch.cat([2*q, 4*q*q], -1)
print(f"x: {x}")

J1 = jacobian(q, x)
print(f"J1: {J1}")
J2 = jacobian2(q, x)
print(f"J2: {J2}")

The output is:

q: tensor([1., 2.], requires_grad=True)
x: tensor([ 2.,  4.,  4., 16.], grad_fn=<CatBackward>)
J1: tensor([[10., 18.],
        [10., 18.],
        [10., 18.],
        [10., 18.]], grad_fn=<StackBackward>)
J2: tensor([[ 2.,  0.],
        [ 0.,  2.],
        [ 8.,  0.],
        [ 0., 16.]], grad_fn=<CopySlices>)

As you can see, jacobian2 gives the desired result, while jacobian seems to sum up the rows and return that sum in every row. Why is that? Originally I thought my mistake was the grad_output variable, but if I pass None instead (since each slice is a “scalar” function), the output is the same.

Hi,

I would say this is caused by the “No free lunch theorem” :smiley:
This is a rather fundamental limitation of AD I’m afraid.

The issue you have here is that all your outs and inputs are dependent: every out is computed from the same input. So the gradient “flowing” back from one out takes the same path as the gradient flowing back from any other, and the contributions are accumulated. That is why you get the sum.
Since all your inputs are actually the same Tensor, they all receive that sum and you get the same result for each of them.
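
To make the summing concrete, here is a minimal sketch that reuses q and x from the question. It passes q once as the input instead of repeating it (repeating it, as in the question, just returns the same tensor several times), and compares the result with a single backward pass seeded with a vector of ones. The variable names summed and vjp are only illustrative.

```python
import torch
from torch.autograd import grad

q = torch.tensor([1.0, 2.0], requires_grad=True)
x = torch.cat([2 * q, 4 * q * q], -1)

# All four output slices handed to grad at once, each seeded with 1.
outs = [x[i:i + 1] for i in range(x.numel())]
seeds = [torch.ones(1) for _ in outs]
(summed,) = grad(outs, q, grad_outputs=seeds, retain_graph=True)

# The same thing written as one vector-Jacobian product: ones^T @ J.
(vjp,) = grad(x, q, grad_outputs=torch.ones_like(x), retain_graph=True)

print(summed)  # values 10 and 18: the column sums of the true Jacobian
print(vjp)     # same values
```

In other words, one backward call yields one vector-Jacobian product, so a call that shares a single input across all outputs cannot return the individual rows.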

The solution here would be to use a different input Tensor for each forward so that they are independent. But then you have to run the same forward many times, just to do in a single backward call what the for loop already does.
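
A rough sketch of that idea, under some assumptions: the function f below is just the computation from the question wrapped in a helper (the name is made up for this example), and the independent inputs are created by stacking one copy of q per output component into a batch. Seeding the backward pass with an identity matrix then picks output i from copy i. The result is checked against torch.autograd.functional.jacobian, the built-in mentioned in the question.

```python
import torch
from torch.autograd import grad

def f(q):
    # Same computation as in the question, applied along the last dimension.
    return torch.cat([2 * q, 4 * q * q], -1)

q = torch.tensor([1.0, 2.0])
output_dim = f(q).size(-1)  # 4

# One independent copy of q per output component: shape (output_dim, input_dim).
q_rep = q.repeat(output_dim, 1).requires_grad_(True)

# The forward pass now runs output_dim times (batched); row i depends only on q_rep[i].
x_rep = f(q_rep)  # shape (output_dim, output_dim)

# Identity seed: entry (i, i) selects output component i of copy i.
(J,) = grad(x_rep, q_rep, grad_outputs=torch.eye(output_dim))

# Cross-check with the built-in helper.
reference = torch.autograd.functional.jacobian(f, q)

print(J)  # rows: [2, 0], [0, 2], [8, 0], [0, 16]
print(torch.allclose(J, reference))
```

This gets the full Jacobian out of a single backward call, but only by paying for output_dim forward passes and the memory that goes with them, which is the trade-off described above.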

Thank you very much. Totally makes sense :slight_smile: