I tried the Jacobian computation with JAX (jax.jacfwd and jax.jacrev) on one GPU (NVIDIA Tesla V100).
I defined the same linear NN for JAX using Flax:
import flax.linen as nn  # flax.linen API; Dense takes use_bias, not bias

class FlaxNet(nn.Module):
    @nn.compact
    def __call__(self, x):
        s1 = nn.Dense(features=100, use_bias=False)(x)
        a1 = nn.relu(s1)
        s2 = nn.Dense(features=1000, use_bias=False)(a1)
        return s2
Jacobian by JAX
def jax_jacobian(x, model, mode='fwd'):
    # jacfwd builds the Jacobian column by column (one JVP per input);
    # jacrev builds it row by row (one VJP per output)
    if mode == 'fwd':
        return jax.jacfwd(model)(x)
    return jax.jacrev(model)(x)
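Note that jax.jacfwd and jax.jacrev expect model to be a pure function of x, so the Flax parameters have to be closed over. A minimal self-contained sketch of the same two-layer architecture as a plain JAX function (the weight names W1/W2 and the 0.01 init scale are my own illustrative choices, not the benchmark code):

```python
import jax
import jax.numpy as jnp

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
W1 = jax.random.normal(key1, (100, 1000)) * 0.01   # Dense: 1000 -> 100
W2 = jax.random.normal(key2, (1000, 100)) * 0.01   # Dense: 100 -> 1000

def model(x):
    # same architecture as FlaxNet: Dense -> relu -> Dense, no biases
    return W2 @ jax.nn.relu(W1 @ x)

x = jnp.ones(1000)                 # n_inputs = 1000
J_fwd = jax.jacfwd(model)(x)
J_rev = jax.jacrev(model)(x)
print(J_fwd.shape)                 # (1000, 1000): n_outputs x n_inputs
```

Both modes return the same (n_outputs, n_inputs) matrix; they differ only in how many passes they take to build it.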
For simplicity, I just set the mini-batch size to bs = 1.
bs: 1
n_inputs: 1000
n_outputs: 1000
loop: 100
-------------
jax.jacfwd: 1.959s
jax.jacrev: 1.240s
reverse mode: 2.564s
manual mode: 0.471s
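If "manual mode" means exploiting the network's structure (I'm inferring this from the timings, not from the PyTorch code), then for Dense -> ReLU -> Dense with no biases the full Jacobian collapses to a single masked matmul, J = W2 · diag(1[s1 > 0]) · W1, which is why it can beat generic autodiff. A sketch of that identity in JAX (W1/W2 are illustrative stand-ins for the trained weights):

```python
import jax
import jax.numpy as jnp

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
W1 = jax.random.normal(key1, (100, 1000)) * 0.01
W2 = jax.random.normal(key2, (1000, 100)) * 0.01

def model(x):
    return W2 @ jax.nn.relu(W1 @ x)

def manual_jacobian(x):
    # d relu(s1)/d s1 is a 0/1 mask on the hidden pre-activations
    mask = (W1 @ x > 0).astype(x.dtype)
    # W2 * mask scales W2's columns, i.e. W2 @ diag(mask)
    return (W2 * mask) @ W1

x = jax.random.normal(jax.random.PRNGKey(1), (1000,))
# matches autodiff up to float32 round-off
print(jnp.allclose(manual_jacobian(x), jax.jacrev(model)(x), atol=1e-4))
```

One (1000, 100) x (100, 1000) matmul replaces the 1000 JVPs/VJPs that jacfwd/jacrev would run, which is consistent with the gap in the timings above.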
As a beginner with JAX, I may not be using it to its full potential, but the manual mode in PyTorch is still the fastest.
(example code for these runs)