Hello,
I stumbled on this page of the PyTorch docs, and I had a few questions about the use of this method.
First of all, I’m not really comfortable with auto-diff, and I’ve had a hard time understanding the difference between reverse-mode AD and forward-mode AD. The main difference, as I understand it, is that forward mode is run alongside the forward pass, in order to minimize the number of operations needed to compute a JVP.
If this understanding is correct, I’d expect the forward-mode AD JVP to be faster than the double-grad trick, since it runs one loop instead of two.
Yet when I benchmark both methods, the double-grad trick still seems to run faster than forward-mode AD.
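For reference, here is how I understand "running alongside the forward pass": a minimal sketch using PyTorch's dedicated forward-mode API, `torch.autograd.forward_ad` (assuming a recent enough PyTorch version that ships it). The tangent `u` is attached to the input as a dual number, and the JVP pops out of the same forward pass that computes the output:

```python
import torch
import torch.nn as nn
import torch.autograd.forward_ad as fwAD

model = nn.Linear(20, 20)
x = torch.randn(16, 20)
u = torch.randn(16, 20)  # tangent vector for the JVP

with fwAD.dual_level():
    # Pair the primal input with its tangent.
    dual_x = fwAD.make_dual(x, u)
    # One forward pass computes both the output and the JVP.
    dual_y = model(dual_x)
    y, jvp_val = fwAD.unpack_dual(dual_y)

# For a linear layer, the JVP is just u @ W^T.
assert torch.allclose(jvp_val, u @ model.weight.T, atol=1e-5)
```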
Here’s what I ran to evaluate both methods (I’ve read in a GitHub issue that the torch.autograd.functional
API does use forward-mode AD):
import torch
import torch.nn as nn
from torch.autograd.functional import jvp
def Ju(x, y, u):
    # Double-grad trick: differentiating a VJP w.r.t. the dummy cotangent w yields the JVP.
    w = torch.ones_like(y, requires_grad=True)
    return torch.autograd.grad(
        torch.autograd.grad(y, x, w, create_graph=True), w, u, create_graph=True
    )[0]
Ju_fast = lambda u: torch.matmul(u, model.weight.T)  # The exact JVP of a linear model
model = nn.Linear(20, 20)
input_ = torch.randn(16, 20)
x = input_.clone()
x.requires_grad_()
y = model(x)
u = torch.randn(16, 20)
_, grads_fwAD = jvp(model, (x,), (u,))
grads_2trick = Ju(x, y, u)
grads_fast = Ju_fast(u)
assert torch.isclose(grads_fwAD, grads_2trick).all()
assert torch.isclose(grads_fwAD, grads_fast).all()
%%timeit -n 10 -r 1000
x = input_.clone()
x.requires_grad_()
y = model(x)
Ju(x, y, u)
# Yields 131 µs ± 48.3 µs per loop (mean ± std. dev. of 1000 runs, 10 loops each)
%%timeit -n 10 -r 1000
jvp(model, (input_,), (u,))
# Yields 192 µs ± 94 µs per loop (mean ± std. dev. of 1000 runs, 10 loops each)
%%timeit -n 10 -r 1000
Ju_fast(u)
# Yields 14.6 µs ± 6.26 µs per loop (mean ± std. dev. of 1000 runs, 10 loops each)
From what I understood, forward-mode AD should compute J @ u alongside W @ x + b for the linear model (for an elementwise function, it seems that simply computing u^T J gives Ju, since J is diagonal), which would result in a very fast computation of the JVP.
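To sanity-check that claim, the full Jacobian can be built explicitly. A small sketch using `torch.autograd.functional.jacobian` (the tiny shapes here are illustrative, not the ones from my benchmark):

```python
import torch
import torch.nn as nn
from torch.autograd.functional import jacobian

model = nn.Linear(3, 3)
x = torch.randn(3)
u = torch.randn(3)

# For a linear model y = W @ x + b, the Jacobian dy/dx is just W.
J = jacobian(model, x)
assert torch.allclose(J, model.weight)

# So the JVP J @ u reduces to a single matmul, no AD machinery needed.
assert torch.allclose(J @ u, u @ model.weight.T)
```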
Am I mistaken in my understanding of what’s happening inside PyTorch? Are these results to be expected, and why?
Thank you for your answers !