Is there any way to compute multiple JVPs without repeatedly computing the function value?

I did some tests on torch.func.vjp and torch.func.jvp. I found that torch.func.vjp computes the function value only once, and then allows the VJP to be computed multiple times for the given primals. However, torch.func.jvp is different: it always recomputes the function value together with the JVP, even if the primals are unchanged. Here is an example:

import torch

class Test(torch.autograd.Function):
    @staticmethod
    def forward(x, a, b, c):
        print("forward")

        return a * x ** 2 + b * x + c

    @staticmethod
    def setup_context(ctx, inputs, output):
        print("setup_context")

        x, a, b, c = inputs
        ctx.saved_vars = (x.item(), a.item(), b.item())

    @staticmethod
    def backward(ctx, grad_y):
        print("backward")

        x, a, b = ctx.saved_vars
        grad_y = grad_y.item()

        grad_x = (2 * a * x + b) * grad_y
        grad_a = (x ** 2) * grad_y
        grad_b = x * grad_y
        grad_c = grad_y

        return (torch.tensor(grad_x, dtype=torch.float32),
                torch.tensor(grad_a, dtype=torch.float32),
                torch.tensor(grad_b, dtype=torch.float32),
                torch.tensor(grad_c, dtype=torch.float32))

    @staticmethod
    def jvp(ctx, der_x, der_a, der_b, der_c):
        print("jvp")

        x, a, b = ctx.saved_vars
        der_x = der_x.item()
        der_a = der_a.item()
        der_b = der_b.item()
        der_c = der_c.item()

        der_y = (2 * a * x + b) * der_x + (x ** 2) * der_a + x * der_b + der_c

        return torch.tensor(der_y, dtype=torch.float32)

def test_func(x, a, b, c):
    return Test.apply(x, a, b, c)

x = torch.tensor(2, dtype=torch.float32)
a = torch.tensor(3, dtype=torch.float32)
b = torch.tensor(4, dtype=torch.float32)
c = torch.tensor(5, dtype=torch.float32)

print("Multiple vjp-----------------------")

y, vjp_func = torch.func.vjp(test_func, x, a, b, c)
for i in range(5):
    print(i)
    grad_y = torch.randn([], dtype=torch.float32)
    grad_x, grad_a, grad_b, grad_c = vjp_func(grad_y)

print("Multiple jvp-----------------------")

for i in range(5):
    print(i)
    der_x = torch.randn([], dtype=torch.float32)
    der_a = torch.randn([], dtype=torch.float32)
    der_b = torch.randn([], dtype=torch.float32)
    der_c = torch.randn([], dtype=torch.float32)
    y, der_y = torch.func.jvp(test_func, (x, a, b, c), (der_x, der_a, der_b, der_c))

The output is:

Multiple vjp-----------------------
forward
setup_context
setup_context
0
backward
1
backward
2
backward
3
backward
4
backward
Multiple jvp-----------------------
0
forward
setup_context
setup_context
jvp
1
forward
setup_context
setup_context
jvp
2
forward
setup_context
setup_context
jvp
3
forward
setup_context
setup_context
jvp
4
forward
setup_context
setup_context
jvp

This shows that each call to torch.func.jvp invokes both Test.forward and Test.jvp. In contrast, torch.func.vjp returns a vjp_func object that only calls Test.backward. Is there a way to call Test.jvp on given primals without calling Test.forward?

Hi! I’m not a forward-mode expert, but I’ll try to help as much as I can.

I wonder if there is a way to call Test.jvp on given primals without calling Test.forward ?

As far as I know, there’s no way to do that. However, I can think of two different workarounds:

vmapped torch.func.jvp

If you want to call torch.func.jvp on multiple different derivatives, you may benefit from using torch.vmap, so that you only have a single (vectorized) forward pass and a single (vectorized) call to Test.jvp.

For instance, you could replace:

for i in range(5):
    der_x = torch.randn([], dtype=torch.float32)
    der_a = torch.randn([], dtype=torch.float32)
    der_b = torch.randn([], dtype=torch.float32)
    der_c = torch.randn([], dtype=torch.float32)
    y, der_y = torch.func.jvp(test_func, (x, a, b, c), (der_x, der_a, der_b, der_c))

with:

der_x = torch.randn([5], dtype=torch.float32)
der_a = torch.randn([5], dtype=torch.float32)
der_b = torch.randn([5], dtype=torch.float32)
der_c = torch.randn([5], dtype=torch.float32)
vmapped_jvp = torch.vmap(torch.func.jvp, in_dims=(None, None, 0))
y, der_y = vmapped_jvp(test_func, (x, a, b, c), (der_x, der_a, der_b, der_c))

For that to work, you also have to add generate_vmap_rule = True as a class attribute of Test, and replace all of your .item() calls with .detach() (.item() cannot be called on vmapped tensors).

With the loop, you'll obtain 5 tensors of shape [] for y and der_y, while with the vmapped version, y and der_y will be tensors of shape [5]. The values will be the same up to the randomness of der_x, der_a, der_b and der_c. But I guess that in a real use case, you won't use random values there anyway.

Here is a slightly modified version to show the equivalence between the two methods:

print("Multiple jvp-----------------------")

torch.manual_seed(0)
der = torch.randn([5, 4])
for i in range(5):
    print(i)
    der_x, der_a, der_b, der_c = der[i, 0], der[i, 1], der[i, 2], der[i, 3]
    y, der_y = torch.func.jvp(test_func, (x, a, b, c), (der_x, der_a, der_b, der_c))
    print(y)
    print(der_y)

print("jvp (vmapped) -----------------------")

torch.manual_seed(0)
der = torch.randn([5, 4])
der_x, der_a, der_b, der_c = der[:, 0], der[:, 1], der[:, 2], der[:, 3]
vmapped_jvp = torch.vmap(torch.func.jvp, in_dims=(None, None, 0))
y, der_y = vmapped_jvp(test_func, (x, a, b, c), (der_x, der_a, der_b, der_c))
print(y)
print(der_y)

Output:

Multiple jvp-----------------------
0
forward
setup_context
setup_context
jvp
tensor(25.)
tensor(-23.5579)
1
forward
setup_context
setup_context
jvp
tensor(25.)
tensor(4.5313)
2
forward
setup_context
setup_context
jvp
tensor(25.)
tensor(10.0121)
3
forward
setup_context
setup_context
jvp
tensor(25.)
tensor(29.6241)
4
forward
setup_context
setup_context
jvp
tensor(25.)
tensor(24.0090)
jvp (vmapped) -----------------------
forward
setup_context
setup_context
jvp
tensor([25., 25., 25., 25., 25.])
tensor([-23.5579,   4.5313,  10.0121,  29.6241,  24.0090])

Here, you can also see that the value of y is always the same. This is expected, because the input values (x, a, b and c) are also fixed. In the vmapped case, you can set out_dims=(None, 0) so that y is returned as a single value instead of being repeated 5 times. I don't know if that saves computation or memory, though.

The line changes to:

vmapped_jvp = torch.vmap(torch.func.jvp, in_dims=(None, None, 0), out_dims=(None, 0))

and the output changes to:

tensor(25.)  # only one value for `y`
tensor([-23.5579,   4.5313,  10.0121,  29.6241,  24.0090])

Define your own function

The second solution would be to define a function jvp_with_inputs in Test as follows:

@staticmethod
def jvp_with_inputs(x, a, b, der_x, der_a, der_b, der_c):
    x = x.detach()
    a = a.detach()
    b = b.detach()

    der_x = der_x.detach()
    der_a = der_a.detach()
    der_b = der_b.detach()
    der_c = der_c.detach()

    der_y = (2 * a * x + b) * der_x + (x ** 2) * der_a + x * der_b + der_c

    return der_y

Then, you could do:

print("Multiple jvp_with_inputs-----------------------")

torch.manual_seed(0)
y = test_func(x, a, b, c)
der = torch.randn([5, 4])
for i in range(5):
    print(i)
    der_x, der_a, der_b, der_c = der[i, 0], der[i, 1], der[i, 2], der[i, 3]
    der_y = Test.jvp_with_inputs(x, a, b, der_x, der_a, der_b, der_c)
    print(der_y)

Output:

Multiple jvp_with_inputs-----------------------
forward
setup_context
0
tensor(-23.5579)
1
tensor(4.5313)
2
tensor(10.0121)
3
tensor(29.6241)
4
tensor(24.0090)

This may be a pain to implement in a real use case, though (you would have to implement such a function for every autograd Function you use), compared to the first solution, which should work in most cases (in all autograd Functions where vmap is supported). But it may be more efficient. Again, I'm not an expert in this, so take it with a grain of salt, and if anyone knows better, don't hesitate to correct me!

I hope this helps!

Hi Haoren!

vjp performs reverse-mode (ā€œstandardā€) automatic differentiation. Reverse-mode AD involves
a forward pass during which the ā€œcomputation graphā€ is constructed (and the function value
itself is computed). Then a backward pass computes the gradient of a scalar-valued function
(which might be a linear combination of the elements of a vector-valued function output) by
backpropagating intermediate gradients through the computation graph. Note that the
computation graph built during a single forward pass can be reused by multiple backward
passes (as you might do to compute the full jacobian of a vector-valued function).
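As a tiny illustration of that reuse (my own example, not from the thread, using plain autograd rather than torch.func):

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3  # one forward pass builds the computation graph

# The same graph is traversed by several backward passes;
# retain_graph=True keeps it alive after the first one.
g1, = torch.autograd.grad(y, x, retain_graph=True)
g2, = torch.autograd.grad(y, x)  # reuses the graph; no new forward pass
print(g1, g2)  # dy/dx = 3 * x**2 = 12 both times
```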

A backward pass can be as computationally expensive as the corresponding forward pass
(although it won’t necessarily be, depending on the details of the function). Also, just
computing the function’s value without constructing the computation graph is cheaper than
performing a ā€œfullā€ forward pass that does construct the computation graph.

So in your vjp test you see a single forward pass that computes the function value once,
but you see multiple backward passes, one for each linear combination of output function
values you evaluate. (Your test case is somewhat degenerate. Your function is scalar
valued, so the different ā€œlinear combinations of output function valuesā€ are, rather trivially,
just the single scalar output value multiplied by the random scale factor grad_y.)

In contrast, jvp performs forward-mode AD in which a single pass (just a single forward
pass with no backward pass) is performed. During this forward pass, forward-mode AD
computes the directional derivatives of all elements of a vector-valued output with respect
to the direction specified by the ā€œtangentā€ using ā€œdual variables.ā€ Computing the full jacobian
of a function (of a vector-valued input) requires multiple forward-mode forward passes, one
for a ā€œtangentā€ in the direction of each scalar element of the vector-valued input. The main
cost of a forward-mode forward pass is that of ā€œforward-propagatingā€ those directional
derivatives (without necessarily constructing a computation graph). The cost of computing
the function value itself, while not free, is typically small by comparison and the function
value basically just comes along for the ride.
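PyTorch exposes this dual-variable mechanism directly through torch.autograd.forward_ad; here is a minimal sketch of one forward-mode pass (using the same polynomial as the question, with a tangent only in the x direction):

```python
import torch
import torch.autograd.forward_ad as fwAD

x, a, b, c = (torch.tensor(v, dtype=torch.float32) for v in (2, 3, 4, 5))
tangent = torch.tensor(1.0)  # direction d/dx

with fwAD.dual_level():
    # Attach the tangent to x, making it a "dual" variable.
    dual_x = fwAD.make_dual(x, tangent)
    dual_y = a * dual_x ** 2 + b * dual_x + c
    # The single forward pass yields the value and its directional derivative.
    y, der_y = fwAD.unpack_dual(dual_y)

print(y, der_y)  # y = 25, dy/dx = 2*a*x + b = 16
```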

So in your jvp test you see multiple forward-mode forward passes (analogous to the
multiple reverse-mode backward passes you see in your vjp test) and each forward pass
does, indeed, recompute the function value itself.

No, there isn't a way to call Test.jvp without also running Test.forward. This is because jvp uses just a forward pass. To compute the directional derivative
for multiple tangents you have to run multiple forward passes (each of which computes the
function value itself). Unlike reverse-mode AD, forward-mode doesn’t ā€œsave any of its workā€
in the analog of a computation graph that could be reused in multiple passes, for example,
to avoid recomputing the function value itself. (But again, computing the function value
itself typically accounts for only a modest part of the total computation cost.)

Best.

K. Frank


Thanks a lot for your reply! I've tested the method that uses vmap, but it can't be applied to my real use case. The reason is that it requires all the computations in Test to be implemented with PyTorch. In other words, it prevents me from fetching the tensors' data and implementing the computations in CUDA.

The other answer told me it's impossible to call Test.jvp alone, because Test.forward is required to construct the computation graph. However, I've found a way to avoid repeatedly computing the function value: record the function value on the first call of Test.forward, and directly return the recorded value on subsequent calls. A shortcoming is that I have to manually delete the recorded value when it becomes stale, which may be complex if Test appears multiple times in the computation graph. Fortunately, in my use case, each of my custom operators appears only once.
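A rough sketch of this caching idea (the names and the class-level cache are my assumptions; the cache must be cleared by hand whenever the primals change):

```python
import torch

class CachedTest(torch.autograd.Function):
    _cached_y = None    # assumed class-level cache; clear manually when primals change
    compute_count = 0   # just to verify the expensive value is computed once

    @staticmethod
    def forward(x, a, b, c):
        # setup_context still runs on every call, but the expensive
        # computation of the function value itself is skipped.
        if CachedTest._cached_y is None:
            CachedTest.compute_count += 1
            CachedTest._cached_y = a * x ** 2 + b * x + c
        return CachedTest._cached_y

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, a, b, c = inputs
        ctx.saved_vars = (x.detach(), a.detach(), b.detach())

    @staticmethod
    def jvp(ctx, der_x, der_a, der_b, der_c):
        x, a, b = ctx.saved_vars
        return (2 * a * x + b) * der_x + x ** 2 * der_a + x * der_b + der_c

def cached_func(x, a, b, c):
    return CachedTest.apply(x, a, b, c)

x, a, b, c = (torch.tensor(v, dtype=torch.float32) for v in (2, 3, 4, 5))

for _ in range(5):
    tangents = tuple(torch.randn([]) for _ in range(4))
    y, der_y = torch.func.jvp(cached_func, (x, a, b, c), tangents)

print(CachedTest.compute_count)  # the polynomial was evaluated only once
CachedTest._cached_y = None      # manual invalidation before the primals change
```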


Thanks a lot for your reply! I've used some other AD systems, and most of them compute forward-mode AD together with the function value in a single interface (such as the AD module of Slang), which makes it impossible to compute the JVP alone. However, I found that PyTorch provides an independent interface, torch.autograd.Function.jvp, for forward-mode AD. So at first I thought there might be a way to call it alone.

I haven't yet measured how much time computing the function value takes in my real use case. However, I've found a way to avoid repeatedly computing it: record the function value on the first call of Test.forward, and directly return the recorded value on subsequent calls. A shortcoming is that I have to manually delete the recorded value when it becomes stale, which may be complex if Test appears multiple times in the computation graph. Fortunately, in my use case, each of my custom operators appears only once. I will decide whether it is necessary to use this method after efficiency tests.

You may want to look at torch.library, which lets you use your custom operations as if they came from PyTorch, and in particular at torch.library.register_vmap, which lets you register a vmapped version of them. I've never used it personally, so I'm not sure, but maybe it will suit your needs.