Efficient computation with multiple grad_output's in autograd.grad

I tried the Jacobian computation with JAX (jax.jacfwd and jax.jacrev) on a single GPU (NVIDIA Tesla V100).

I defined the same linear NN in JAX using Flax:

import flax

class FlaxNet(flax.nn.Module):

    def apply(self, x):
        # Two dense layers without biases, 1000 -> 100 -> 1000, with a ReLU in between.
        s1 = flax.nn.Dense(x, features=100, bias=False)
        a1 = flax.nn.relu(s1)
        s2 = flax.nn.Dense(a1, features=1000, bias=False)
        return s2
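
The snippet above is written against the old flax.nn API, which has since been replaced by flax.linen. For readers on a recent Flax release, a roughly equivalent module would look like the sketch below; the class name FlaxNetLinen and the init/apply boilerplate are my own, not from the original post.

import jax
import jax.numpy as jnp
import flax.linen as nn

class FlaxNetLinen(nn.Module):
    # Same 1000 -> 100 -> 1000 linear-ReLU-linear network, flax.linen version.

    @nn.compact
    def __call__(self, x):
        s1 = nn.Dense(features=100, use_bias=False)(x)
        a1 = nn.relu(s1)
        s2 = nn.Dense(features=1000, use_bias=False)(a1)
        return s2

# Bind the parameters so the forward pass becomes a pure function of x,
# which is what jax.jacfwd / jax.jacrev expect.
model = FlaxNetLinen()
x = jnp.ones((1000,))
params = model.init(jax.random.PRNGKey(0), x)
jacobian = jax.jacrev(lambda inp: model.apply(params, inp))(x)  # shape (1000, 1000)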

Jacobian by JAX

import jax

def jax_jacobian(x, model, mode='fwd'):
    # jacfwd builds the Jacobian column by column (efficient when there are few inputs);
    # jacrev builds it row by row (efficient when there are few outputs).
    if mode == 'fwd':
        Jx = jax.jacfwd(model)(x)
    else:
        Jx = jax.jacrev(model)(x)
    return Jx
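
When timing these, two JAX-specific details matter: the first call includes XLA compilation, and dispatch is asynchronous, so the timer should only stop after block_until_ready(). A minimal timing sketch, using a stand-in model_fn since the Flax initialization boilerplate is omitted above, could look like this:

import time
import jax
import jax.numpy as jnp

def model_fn(x):
    # Stand-in for the network's forward pass; any pure JAX function works here.
    W1 = jnp.ones((1000, 100)) / 1000.0
    W2 = jnp.ones((100, 1000)) / 100.0
    return jax.nn.relu(x @ W1) @ W2

x = jax.random.normal(jax.random.PRNGKey(0), (1000,))

jac_fn = jax.jit(jax.jacrev(model_fn))   # or jax.jacfwd(model_fn)
jac_fn(x).block_until_ready()            # warm-up call: compilation happens here

t0 = time.time()
for _ in range(100):
    jac_fn(x).block_until_ready()        # force the async computation to finish
print(f"jax.jacrev: {time.time() - t0:.3f}s for 100 calls")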

For simplicity, I set the mini-batch size to bs = 1.

bs: 1
n_inputs: 1000
n_outputs: 1000
loop: 100
-------------
jax.jacfwd: 1.959s
jax.jacrev: 1.240s
reverse mode: 2.564s
manual mode: 0.471s

As I'm a beginner with JAX, I might not be fully exploiting its potential, but the manual mode in PyTorch is still the fastest.
(example code for these runs)
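
The "manual mode" implementation itself isn't shown in this snippet, so here is only a sketch of the multiple-grad_outputs trick the thread title refers to, not the original code. It assumes a single 1000-dim input and 1000 scalar outputs; the function name manual_jacobian and the Sequential model are my own.

import torch

def manual_jacobian(model, x, n_outputs):
    # Replicate the single input n_outputs times, run one forward pass, and
    # backpropagate a different unit vector (one row of the identity) through
    # each copy, so a single autograd.grad call returns the full Jacobian.
    x_rep = x.detach().repeat(n_outputs, 1).requires_grad_(True)  # (n_outputs, n_inputs)
    y = model(x_rep)                                              # (n_outputs, n_outputs)
    grad_outputs = torch.eye(n_outputs, device=x.device, dtype=x.dtype)
    (jac,) = torch.autograd.grad(y, x_rep, grad_outputs=grad_outputs)
    return jac                                                    # row i = d y_i / d x

# Usage with a network matching the one above (hypothetical names):
model = torch.nn.Sequential(
    torch.nn.Linear(1000, 100, bias=False),
    torch.nn.ReLU(),
    torch.nn.Linear(100, 1000, bias=False),
)
x = torch.randn(1, 1000)
jac = manual_jacobian(model, x, n_outputs=1000)  # shape (1000, 1000)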
