I ran some tests on torch.func.vjp and torch.func.jvp. I found that torch.func.vjp computes the function value only once and lets you evaluate the VJP at the given primals multiple times. torch.func.jvp behaves differently: it recomputes the function value together with the JVP on every call, even when the primals are unchanged. Here is an example:
```python
import torch

class Test(torch.autograd.Function):
    @staticmethod
    def forward(x, a, b, c):
        print("forward")
        return a * x ** 2 + b * x + c

    @staticmethod
    def setup_context(ctx, inputs, output):
        print("setup_context")
        x, a, b, c = inputs
        ctx.saved_vars = (x.item(), a.item(), b.item())

    @staticmethod
    def backward(ctx, grad_y):
        print("backward")
        x, a, b = ctx.saved_vars
        grad_y = grad_y.item()
        grad_x = (2 * a * x + b) * grad_y
        grad_a = (x ** 2) * grad_y
        grad_b = x * grad_y
        grad_c = grad_y
        return (
            torch.tensor(grad_x, dtype=torch.float32),
            torch.tensor(grad_a, dtype=torch.float32),
            torch.tensor(grad_b, dtype=torch.float32),
            torch.tensor(grad_c, dtype=torch.float32),
        )

    @staticmethod
    def jvp(ctx, der_x, der_a, der_b, der_c):
        print("jvp")
        x, a, b = ctx.saved_vars
        der_x = der_x.item()
        der_a = der_a.item()
        der_b = der_b.item()
        der_c = der_c.item()
        der_y = (2 * a * x + b) * der_x + (x ** 2) * der_a + x * der_b + der_c
        return torch.tensor(der_y, dtype=torch.float32)

def test_func(x, a, b, c):
    return Test.apply(x, a, b, c)

x = torch.tensor(2, dtype=torch.float32)
a = torch.tensor(3, dtype=torch.float32)
b = torch.tensor(4, dtype=torch.float32)
c = torch.tensor(5, dtype=torch.float32)

print("Multiple vjp-----------------------")
y, vjp_func = torch.func.vjp(test_func, x, a, b, c)
for i in range(5):
    print(i)
    grad_y = torch.randn([], dtype=torch.float32)
    grad_x, grad_a, grad_b, grad_c = vjp_func(grad_y)

print("Multiple jvp-----------------------")
for i in range(5):
    print(i)
    der_x = torch.randn([], dtype=torch.float32)
    der_a = torch.randn([], dtype=torch.float32)
    der_b = torch.randn([], dtype=torch.float32)
    der_c = torch.randn([], dtype=torch.float32)
    y, der_y = torch.func.jvp(test_func, (x, a, b, c), (der_x, der_a, der_b, der_c))
```
The output is:
```
Multiple vjp-----------------------
forward
setup_context
setup_context
0
backward
1
backward
2
backward
3
backward
4
backward
Multiple jvp-----------------------
0
forward
setup_context
setup_context
jvp
1
forward
setup_context
setup_context
jvp
2
forward
setup_context
setup_context
jvp
3
forward
setup_context
setup_context
jvp
4
forward
setup_context
setup_context
jvp
```
This shows that each call to torch.func.jvp invokes both Test.forward and Test.jvp. In contrast, torch.func.vjp returns a vjp_func that only calls Test.backward. Is there a way to call Test.jvp at given primals without re-running Test.forward?
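In pseudocode, what I am after is a JVP analogue of vjp's two-step interface. The function name `make_jvp` below is hypothetical, just to illustrate the shape of the API I am looking for:

```
# Hypothetical API sketch (make_jvp is NOT a real torch.func function):
# forward runs once here, and the returned closure should only call Test.jvp.
y, jvp_func = make_jvp(test_func, (x, a, b, c))  # Test.forward called once
for _ in range(5):
    tangents = tuple(torch.randn([], dtype=torch.float32) for _ in range(4))
    der_y = jvp_func(tangents)                   # only Test.jvp should run
```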