Computing vector-Jacobian and Jacobian-vector products efficiently

I need to compute both the vector-Jacobian product and the Jacobian-vector product at the same time, and then to backprop through both. I have the following code that I have tested and I believe it works correctly:

def vjp(f, x, v, create_graph=True):
    x = x.detach().requires_grad_()
    y = f(x)
    y.backward(v, create_graph=create_graph)
    return x.grad

def jvp(f, x, v, create_graph=True):
    g = lambda v: vjp(f, x, v, create_graph=True)
    return vjp(g, v, v, create_graph=create_graph)

def get_loss(f, x, v):
    vjp_val = vjp(f, x, v)
    jvp_val = jvp(f, x, v)

    return (vjp_val - jvp_val).norm(1)

It is inefficient, however, because it effectively computes f(x).backward(v) twice. I would therefore like to rewrite it in such a way that it does so only once. Here is my attempt:

def get_loss_fast(f, x, v):
    x = x.detach().requires_grad_()
    y = f(x)
    y.backward(v, create_graph=True)
    vjp_val = x.grad

    vjp_val.backward(v, create_graph=True)
    jvp_val = x.grad

    return (vjp_val - jvp_val).norm(1)

This code always returns zero. In fact, inside the get_loss_fast function, the expression vjp_val is jvp_val evaluates to True, which means that the second backward() does not produce a separate result from the first one: both names refer to the same x.grad tensor.

How can I compute this loss efficiently and correctly?
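
The symptom described above is consistent with how .backward() reports its result: it accumulates into x.grad instead of returning a value, so two successive calls cannot be read back from x.grad as two separate quantities. A minimal sketch of that behavior, using a toy function that is not from the question, next to autograd.grad(), which returns the gradient and leaves .grad untouched:

```python
import torch
from torch import autograd

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# First call: x.grad is populated with the gradient of sum(x^2), i.e. 2x.
(x ** 2).sum().backward()
first = x.grad.clone()

# Second call: the new gradient is added to what is already in x.grad.
(x ** 2).sum().backward()
accumulated = x.grad.clone()

# autograd.grad returns the gradient directly and does not write to z.grad.
z = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
returned = autograd.grad((z ** 2).sum(), z)[0]

print(first, accumulated, returned, z.grad)
```

Because autograd.grad() hands back each result as its own tensor, the VJP and the JVP can be kept apart without cloning or zeroing x.grad in between.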

Hi,

First, for tasks like this I would advise using autograd.grad() instead of .backward(), since .backward() might create reference cycles.

Something like this should work, no?

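from torch import autograd

def get_loss_fast(f, x, v):
    # Fresh leaf tensors: gradients will not flow back to the caller's x or v.
    x = x.detach().requires_grad_()
    v = v.detach().requires_grad_()
    y = f(x)
    # Vector-Jacobian product; keep the graph so it can be differentiated again.
    grad_x = autograd.grad(y, x, v, create_graph=True)[0]
    vjp_val = grad_x

    # The VJP is linear in v, so differentiating it with respect to v gives the JVP.
    jvp_val = autograd.grad(grad_x, v, v.detach())[0]

    return (vjp_val - jvp_val).norm(1)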

Note that because of the detach at the beginning, no gradient will flow back to the input x or v if you try to backward the result of that function. Is that something you want?
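
The second autograd.grad() call in the snippet above relies on the VJP being linear in the vector it is contracted with, so differentiating the VJP with respect to that vector yields a JVP. The following is a minimal sketch for checking this numerically. The test function coupled and the helper vjp_and_jvp are hypothetical names introduced here, and torch.autograd.functional.jvp is used only as an independent reference:

```python
import torch
from torch import autograd
from torch.autograd.functional import jvp as reference_jvp

def coupled(x):
    # Hypothetical test function (not from the thread). Its Jacobian is
    # [[x1, x0], [1, 2*x1]], which is not symmetric in general.
    return torch.stack([x[0] * x[1], x[1] ** 2 + x[0]])

def vjp_and_jvp(f, x, v):
    x = x.detach().requires_grad_()
    v = v.detach().requires_grad_()
    # First backward pass: J^T v, kept differentiable with respect to v.
    vjp_val = autograd.grad(f(x), x, v, create_graph=True)[0]
    # Second backward pass: d(J^T v)/dv contracted with v, which is J v.
    jvp_val = autograd.grad(vjp_val, v, v.detach())[0]
    return vjp_val.detach(), jvp_val

x = torch.tensor([2.0, 3.0])
v = torch.tensor([1.0, -1.0])

vjp_val, jvp_val = vjp_and_jvp(coupled, x, v)
_, jvp_ref = reference_jvp(coupled, x, v)
print(vjp_val, jvp_val, jvp_ref)
```

Using the same v as both the cotangent and the tangent only makes sense when the Jacobian is square, which is the setting of this thread.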

Thank you, this seems to return correct results. One question though: why did you put v.detach() in the expression for jvp_val? I.e., why would something like this not work correctly?

def get_loss_fast(f, x, v):
    x = x.detach().requires_grad_()
    v = v.detach().requires_grad_()

    vjp_val = autograd.grad(f(x), x, v, create_graph=True)[0]
    jvp_val = autograd.grad(vjp_val, v, v)[0]  # instead of vjp_val, v, v.detach()

    return (vjp_val - jvp_val).norm(1)

Because of the detach at the beginning, no gradient will flow back to the input x or v if you try to backward the result of that function. Is that something you want?

Yes, it is exactly as I intend.

I used v.detach() simply because we don’t want to create the graph for that backward pass, so I made the grad_outputs not require gradients. But I guess that is not strictly necessary, since we do not set create_graph.

Since I want to backprop through the resulting loss, I (think I) need to set create_graph in the second call to grad(). Otherwise, I am getting the following:

# loss = get_loss_fast(f, x, v)
# loss.backward()

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

where f is an nn.Linear, and x and v are two random vectors. When I specify create_graph=True, there is no such error.

In any case, as far as I understand, the performance penalty for either doing a .detach() or doing a redundant gradient computation is negligible, so there is little point in benchmarking whether this .detach() is useful or not.

A working version of the code for posterity:

from torch import autograd

def get_loss_fast(f, x, v):
    x = x.detach().requires_grad_()
    v = v.detach().requires_grad_()

    vjp_val = autograd.grad(f(x), x, v, create_graph=True)[0]
    jvp_val = autograd.grad(vjp_val, v, v.detach(), create_graph=True)[0]

    return (vjp_val - jvp_val).norm(1)
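
As a usage sketch for the version above, here is how the loss might be backpropagated into the parameters of f. It assumes, as described earlier in the thread, that f is an nn.Linear and that x and v are random vectors; the layer size 3 and the seed are arbitrary choices made for this example:

```python
import torch
from torch import autograd, nn

def get_loss_fast(f, x, v):
    x = x.detach().requires_grad_()
    v = v.detach().requires_grad_()
    vjp_val = autograd.grad(f(x), x, v, create_graph=True)[0]
    jvp_val = autograd.grad(vjp_val, v, v.detach(), create_graph=True)[0]
    return (vjp_val - jvp_val).norm(1)

torch.manual_seed(0)          # arbitrary seed, only for repeatability
lin = nn.Linear(3, 3)         # square layer so that J^T v and J v have the same shape
x = torch.randn(3)
v = torch.randn(3)

loss = get_loss_fast(lin, x, v)
loss.backward()               # works because both grad() calls used create_graph=True

# For a linear layer the Jacobian is the weight matrix, so the loss
# should match ||(W^T - W) v||_1 computed directly.
expected = ((lin.weight.T - lin.weight) @ v).norm(1)
print(loss.item(), expected.item())
print(lin.weight.grad)
```

The caller's x and v receive no gradient here, since the function detaches them on entry; only the parameters of f do.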

All of these code snippets give me a result of 0.

import torch

x = torch.FloatTensor([1, 2, 3])
v = torch.FloatTensor([1, 1, 1])
x = x.detach().requires_grad_()
v = v.detach().requires_grad_()
y = x**3 - 6*x
grad_x = torch.autograd.grad(y, x, v, create_graph=True)[0]
print("Grad_x is {0}".format(grad_x))
vjp_val = grad_x
jvp_val = torch.autograd.grad(grad_x, v, v.detach(), create_graph=True)[0]
print("jvp_val is {0}".format(jvp_val))
print((vjp_val - jvp_val).norm(1))

Results

tensor([ 6., 12., 18.], grad_fn=<CloneBackward>)
Grad_x is tensor([-3.,  6., 21.], grad_fn=<AddBackward0>)
jvp_val is tensor([-3.,  6., 21.], grad_fn=<AddBackward0>)
tensor(0., grad_fn=<NormBackward0>)
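
A zero is the expected value for this particular test function: y = x**3 - 6*x acts element-wise, so its Jacobian is diagonal and therefore symmetric, and the VJP and JVP with the same v coincide. The sketch below contrasts it with a function whose outputs mix several inputs. The names asymmetry_loss and coupled are hypothetical and introduced only for this illustration:

```python
import torch
from torch import autograd
from torch.autograd.functional import jacobian

def asymmetry_loss(f, x, v):
    # Same computation as get_loss_fast in this thread, under a different name.
    x = x.detach().requires_grad_()
    v = v.detach().requires_grad_()
    vjp_val = autograd.grad(f(x), x, v, create_graph=True)[0]
    jvp_val = autograd.grad(vjp_val, v, v.detach(), create_graph=True)[0]
    return (vjp_val - jvp_val).norm(1)

def elementwise(x):
    # The function from the post above: each output depends on one input only.
    return x**3 - 6 * x

def coupled(x):
    # Hypothetical example: output i depends on inputs 0..i, so the
    # Jacobian is lower triangular rather than symmetric.
    return torch.cumsum(x**2, dim=0)

x = torch.tensor([1.0, 2.0, 3.0])
v = torch.ones(3)

loss_elementwise = asymmetry_loss(elementwise, x, v)
loss_coupled = asymmetry_loss(coupled, x, v)

print(jacobian(elementwise, x))   # diagonal matrix
print(loss_elementwise)           # zero, since J^T v equals J v
print(jacobian(coupled, x))       # lower triangular matrix
print(loss_coupled)               # non-zero
```

To test the snippets in this thread, a function with a non-symmetric Jacobian is the more informative choice.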

Well, if you want gradients, you might want to double-check that.
In particular, you will need to remove the detach at the beginning and make sure the inputs require gradients.
Also remove the detach in the last grad call and add create_graph=True.

The following should work:

import torch
from torch import autograd

def f(x):
    return 2 + x**4 + x[0] ** 5

def get_loss_fast(x, v):
    y = f(x)
    grad_x = autograd.grad(y, x, v, create_graph=True)[0]
    vjp_val = grad_x

    jvp_val = autograd.grad(grad_x, v, v, create_graph=True)[0]

    return (vjp_val - jvp_val).norm(1)

inp = (torch.rand(2, dtype=torch.double, requires_grad=True), torch.rand(2, dtype=torch.double, requires_grad=True))
print(autograd.gradcheck(get_loss_fast, inp))

get_loss_fast(*inp).backward()
print(inp[0].grad)
print(inp[1].grad)

In my particular case, I only need the gradients for f, and not for x or v. Thanks for pointing this out though!

Here is a working demo on Colab: https://colab.research.google.com/drive/16RO6uW1L8RS8neC3NUs5wQxfObwHP0i6

Thank you @Tidan.
Can you also explain the mathematical background behind using (vjp_val - jvp_val).norm(1) as a loss? I am not able to get my head around it.

An external link would also suffice.
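
The thread does not say what this loss is used for, so the following is only a reading of the formula itself. It assumes that f maps R^n to R^n, so that the Jacobian J is square and the two products have the same shape:

```latex
\begin{aligned}
J &= \frac{\partial f}{\partial x}(x) \in \mathbb{R}^{n \times n}, \qquad
\texttt{vjp\_val} = J^{\top} v, \qquad
\texttt{jvp\_val} = J v, \\
\text{loss} &= \lVert J^{\top} v - J v \rVert_{1}
             = \lVert (J^{\top} - J)\, v \rVert_{1}.
\end{aligned}
```

Read this way, the loss measures how far the Jacobian is from being symmetric along the direction v, and it vanishes for every v exactly when J is symmetric. One setting where that matters: a smooth vector field on a simply connected domain is the gradient of a scalar function exactly when its Jacobian is symmetric, so a penalty of this form would push f toward being such a gradient field. Whether that is the motivation here is an assumption, not something stated in the thread.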