I need to compute both the vector-Jacobian product and the Jacobian-vector product at the same time, and then backpropagate through both. I have the following code, which I have tested and believe works correctly:
def vjp(f, x, v, create_graph=True):
    # Reverse mode: returns J(x)^T v.
    x = x.detach().requires_grad_()
    y = f(x)
    y.backward(v, create_graph=create_graph)
    return x.grad

def jvp(f, x, v, create_graph=True):
    # Double-backward trick: u -> J^T u is linear in u, so its vjp
    # with respect to u is (J^T)^T v = J v.
    g = lambda u: vjp(f, x, u, create_graph=True)
    return vjp(g, v, v, create_graph=create_graph)

def get_loss(f, x, v):
    vjp_val = vjp(f, x, v)
    jvp_val = jvp(f, x, v)
    return (vjp_val - jvp_val).norm(1)
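For reference, here is the kind of sanity check I ran against torch.autograd.functional (the toy A, f, x, and v below are just my test setup, not part of the actual problem):

import torch
from torch.autograd.functional import vjp as vjp_ref, jvp as jvp_ref

# A linear map with a non-symmetric Jacobian, so that the vjp (A^T v)
# and the jvp (A v) genuinely differ and the loss is nonzero.
A = torch.randn(5, 5)
f = lambda x: A @ x
x = torch.randn(5)
v = torch.randn(5)

assert torch.allclose(vjp(f, x, v), vjp_ref(f, x, v)[1])
assert torch.allclose(jvp(f, x, v), jvp_ref(f, x, v)[1])
print(get_loss(f, x, v))  # equals (A.T @ v - A @ v).norm(1)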
It is, however, inefficient: it effectively computes f(x).backward(v) twice, once in vjp and once more inside jvp. Hence I would like to rewrite it so that this happens only once. Here is my attempt:
def get_loss_fast(f, x, v):
    x = x.detach().requires_grad_()
    y = f(x)
    y.backward(v, create_graph=True)
    vjp_val = x.grad
    vjp_val.backward(v, create_graph=True)
    jvp_val = x.grad
    return (vjp_val - jvp_val).norm(1)
This code always returns zero. In fact, inside the get_loss_fast function, vjp_val is jvp_val evaluates to True: the second backward() accumulates into the very same x.grad tensor instead of producing a new one, so both names alias a single tensor and their difference vanishes.
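The following standalone snippet (unrelated to the f above) illustrates the aliasing I believe is at play:

import torch

# backward() accumulates into x.grad in place, so a reference taken
# after the first call and one taken after the second call point to
# the same tensor object.
x = torch.ones(3, requires_grad=True)
(2 * x).sum().backward()
g1 = x.grad           # tensor([2., 2., 2.])
(2 * x).sum().backward()
g2 = x.grad
print(g1 is g2)       # True
print(g1)             # tensor([4., 4., 4.]) -- updated in place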
How can I compute this loss efficiently and correctly?