Forward mode AD vs double grad trick for Jacobian-vector product

Hello,

I stumbled on this page of the pytorch doc, and I had a few questions about the use of this method.

First of all, I’m not really comfortable with auto-diff, and I’ve had a hard time understanding the difference between reverse mode AD and forward mode AD. The main difference, as far as I understand, is that forward mode runs alongside the forward pass, in order to minimize the number of operations needed to compute a JVP.

If this understanding is correct, I’d expect the forward mode AD JVP to be faster than the double grad trick, since it runs a single pass instead of two.
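
To check my own understanding, here is a minimal toy sketch of how I picture the double grad trick on a made-up elementwise function (all names here are just for illustration): reverse mode gives w^T J for a dummy cotangent w, and differentiating that result with respect to w in the direction u recovers J u, which is where the two backward passes come from.

import torch

# Toy elementwise function y = x**2, whose Jacobian is diag(2 * x).
x = torch.randn(5, requires_grad=True)
y = x ** 2
u = torch.randn(5)                           # direction of the JVP

# Reverse mode computes w^T J for a dummy cotangent w; differentiating that
# result with respect to w along u then yields J u (the "double grad" trick).
w = torch.ones_like(y, requires_grad=True)
vjp, = torch.autograd.grad(y, x, w, create_graph=True)
jvp, = torch.autograd.grad(vjp, w, u)

assert torch.allclose(jvp, 2 * x.detach() * u)   # J u = diag(2x) @ u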

When I benchmark both methods, however, the double grad trick still seems to run faster than forward mode AD.

Here’s what I ran to evaluate both methods (I’ve read in a GitHub issue that the torch.autograd.functional API uses forward mode AD):

import torch
import torch.nn as nn
from torch.autograd.functional import jvp

def Ju(x, y, u):
    # Double grad trick: the inner grad computes w^T J (a VJP with a dummy
    # cotangent w), the outer grad differentiates it w.r.t. w along u to get J u.
    w = torch.ones_like(y, requires_grad=True)
    return torch.autograd.grad(torch.autograd.grad(y, x, w, create_graph=True), w, u, create_graph=True)[0]

Ju_fast = lambda u: torch.matmul(u, model.weight.T)  # Analytic JVP of a linear model: J = model.weight, so J u = u @ W^T

model = nn.Linear(20, 20)
input_ = torch.randn(16, 20)

x = input_.clone()
x.requires_grad_()
y = model(x)
u = torch.randn(16, 20)

_, grads_fwAD = jvp(model, (x,), (u,))
grads_2trick = Ju(x, y, u)
grads_fast = Ju_fast(u)
assert torch.isclose(grads_fwAD, grads_2trick).all()
assert torch.isclose(grads_fwAD, grads_fast).all()
%%timeit -n 10 -r 1000
x = input_.clone()
x.requires_grad_()
y = model(x)
Ju(x, y, u)
# Yields 131 µs ± 48.3 µs per loop (mean ± std. dev. of 1000 runs, 10 loops each)
%%timeit -n 10 -r 1000
jvp(model, (input_,), (u,))
# Yields 192 µs ± 94 µs per loop (mean ± std. dev. of 1000 runs, 10 loops each)
%%timeit -n 10 -r 1000
Ju_fast(u)
# Yields 14.6 µs ± 6.26 µs per loop (mean ± std. dev. of 1000 runs, 10 loops each)

From what I understood, forward mode AD should compute J @ u alongside W @ x + b for the linear model (for an elementwise function, it seems that simply computing u^T J gives Ju, since J is diagonal), which should result in a very fast computation of the JVP.
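
To make that concrete, here is a rough hand-written sketch of what I expect for the linear case (W and b stand in for the model’s weight and bias; no forward-mode machinery, just the matrix identity I expect forward mode to exploit), cross-checked against the double grad trick:

import torch

W, b = torch.randn(20, 20), torch.randn(20)   # stand-ins for model.weight / model.bias
x0 = torch.randn(16, 20)
u0 = torch.randn(16, 20)                      # direction of the JVP

y0 = x0 @ W.T + b                             # primal pass
dy0 = u0 @ W.T                                # tangent pass: the bias drops out, only the linear part acts on u0

# Cross-check against the double grad trick
xr = x0.clone().requires_grad_()
w0 = torch.ones(16, 20, requires_grad=True)
vjp0, = torch.autograd.grad(xr @ W.T + b, xr, w0, create_graph=True)
jvp0, = torch.autograd.grad(vjp0, w0, u0)
assert torch.allclose(dy0, jvp0)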

Am I mistaken in my understanding of what’s happening inside PyTorch? Are these results to be expected, and why?

Thank you for your answers!

jvp from torch.autograd.functional does not seem to use forward AD. If I understand the doc you referred to correctly, forward AD relies on torch.autograd.forward_ad.

My attempt at a fwAD version of your implementation would be:

import torch
import torch.autograd.forward_ad as fwAD

def fwad_jvp(model, input_):
    # https://pytorch.org/tutorials/intermediate/forward_ad_usage.html
    # Note: the tangents here are random and attached to the model parameters,
    # so this is a JVP w.r.t. the weights rather than w.r.t. the input.
    params = {name: p for name, p in model.named_parameters()}
    tangents = {name: torch.rand_like(p) for name, p in params.items()}

    with fwAD.dual_level():
        for name, p in params.items():
            delattr(model, name)
            setattr(model, name, fwAD.make_dual(p, tangents[name]))

        out = model(input_)
        jvp = fwAD.unpack_dual(out).tangent
    return jvp

Disclaimer: this does not return the same values as your other methods (I haven’t investigated further), but it is significantly faster (3 to 4 times for me, using torch version 1.12.1). Maybe this could still help?


Thanks for your answer!

From your code and the documentation of fwAD.make_dual, I put together a function that computes the JVP, and it does seem to run faster than the double back trick. I also increased the size of the input, since the double back trick might be faster for very small input sizes.

The torch version I’m using for these experiments is 1.12. Here’s the new code:

import torch
import torch.autograd.forward_ad as fwAD
import torch.nn as nn
from torch.autograd.functional import jvp


def fwad_jvp(model, input_, u):
    """
    Given a function `model` whose Jacobian is `J`, compute the Jacobian-vector product (`jvp`)
    between `J` and a given vector `u` as follows.

    Example::

        >>> with fwAD.dual_level():
        ...   inp = fwAD.make_dual(x, u)
        ...   out = model(inp)
        ...   y, jvp = fwAD.unpack_dual(out)
    """
    with fwAD.dual_level():
        inp = fwAD.make_dual(input_, u)  # attach the tangent u to the input (not to the parameters)
        out = model(inp)
        y, jvp = fwAD.unpack_dual(out)
    return jvp


def Ju(x, y, u):
    w = torch.ones_like(y, requires_grad=True)
    return torch.autograd.grad(torch.autograd.grad(y, x, w, create_graph=True), w, u, create_graph=True)[0]


Ju_fast = lambda u: torch.matmul(u, model.weight.T)  # Analytic JVP of a linear model: J = model.weight, so J u = u @ W^T

model = nn.Linear(100, 100)
input_ = torch.randn(32, 100)

x = input_.clone()
x.requires_grad_()
y = model(x)
u = torch.randn(32, 100)

_, grads_fwAD = jvp(model, (input_,), (u,))
grads_2trick = Ju(x, y, u)
grads_fast = Ju_fast(u)
grads_fwAD_2 = fwad_jvp(model, input_, u)
# Make sure that all the gradients are equal
assert torch.isclose(grads_fwAD, grads_2trick).all()
assert torch.isclose(grads_fwAD, grads_fast).all()
assert torch.isclose(grads_fwAD, grads_fwAD_2).all()
%%timeit -n 100 -r 1000
x = input_.clone()
x.requires_grad_()
y = model(x)
Ju(x, y, u)
# => 282 µs ± 436 µs per loop (mean ± std. dev. of 1000 runs, 100 loops each)
%%timeit -n 100 -r 1000
jvp(model, (input_,), (u,))
# => 344 µs ± 542 µs per loop (mean ± std. dev. of 1000 runs, 100 loops each)
%%timeit -n 100 -r 1000
fwad_jvp(model, input_, u)
# => 180 µs ± 356 µs per loop (mean ± std. dev. of 1000 runs, 100 loops each)
%%timeit -n 100 -r 1000
Ju_fast(u)
# => 22.8 µs ± 125 µs per loop (mean ± std. dev. of 1000 runs, 100 loops each)

I’d say that jvp from torch.autograd.functional indeed does not use forward mode AD, even though I’d read that in the GitHub issue related to this update (if I remember correctly, its documentation even mentions the double backwards trick). It’s much slower than a direct matrix product, but I guess there’s no way to get that speed with all the overhead added by the framework. The fwad_jvp version, on the other hand, does run faster than the explicit double back trick.

Cheers.
