This snippet shows a way to convert a pure JAX function into an autograd-differentiable PyTorch function, so that it can be used in backward() and autograd.grad() calls while supporting higher-order differentiation.
import torch as th
import jax
import jax.numpy as jnp
"""
Author: Josue N Rivera
"""
def t2j(tensor: th.Tensor) -> jax.Array:
    """Hand a PyTorch tensor to JAX as an array through DLPack.

    The tensor is detached from the autograd graph and made contiguous
    before being imported with ``jnp.from_dlpack``.
    """
    detached = tensor.detach()
    dense = detached.contiguous()
    return jnp.from_dlpack(dense)
def j2t(array: jax.Array) -> th.Tensor:
    """Expose a JAX array to PyTorch as a tensor through DLPack."""
    tensor = th.from_dlpack(array)
    return tensor
def j2t_fun(fn):
    r"""
    Wrap a pure JAX function (N array inputs -> 1 array output) as a
    PyTorch-autograd-differentiable callable.

    Gradients are evaluated through ``jax.vjp`` and bridged with DLPack. The
    backward pass is itself built from :func:`j2t_fun` wrappers, so the
    callable supports differentiation to arbitrary order (e.g. ``dfdxx`` via
    repeated :func:`torch.autograd.grad`).

    Args:
        fn: Pure JAX callable taking ``N`` positional arrays and returning a
            single array.

    Returns:
        A callable taking ``N`` positional ``torch.Tensor`` arguments and
        returning a single ``torch.Tensor`` that can take part in
        ``backward()`` and ``torch.autograd.grad`` calls.

    Note: This can be used directly as a decorator for JAX functions.
    """
    import functools  # local import: keeps the file's import block untouched

    def vjp_wrt(index, n_inputs):
        """Build the JAX function ``(inputs..., cotangent) -> VJP w.r.t. input ``index``."""
        def vjp_fn(*inputs_and_cotangent):
            inputs = inputs_and_cotangent[:n_inputs]
            cotangent = inputs_and_cotangent[n_inputs]
            _, pullback = jax.vjp(fn, *inputs)
            return pullback(cotangent)[index]
        return vjp_fn

    # Defined once per wrapped function rather than on every call.
    class JaxFn(th.autograd.Function):
        @staticmethod
        def forward(ctx, *tensors):
            ctx.save_for_backward(*tensors)
            return j2t(fn(*[t2j(t) for t in tensors]))

        @staticmethod
        def backward(ctx, grad):
            tensors = ctx.saved_tensors
            n_inputs = len(tensors)
            # Each gradient is produced by another j2t_fun wrapper so that it
            # stays differentiable. Inputs that do not need a gradient get
            # ``None`` and skip the jax.vjp evaluation entirely.
            return tuple(
                j2t_fun(vjp_wrt(i, n_inputs))(*tensors, grad) if needed else None
                for i, needed in enumerate(ctx.needs_input_grad)
            )

    @functools.wraps(fn)
    def wrapped(*args: th.Tensor) -> th.Tensor:
        return JaxFn.apply(*args)

    return wrapped
if __name__ == "__main__":
    # Demo: wrap a jitted JAX function and differentiate it from PyTorch.
    @j2t_fun
    @jax.jit
    def afun(x: jax.Array, u: jax.Array) -> jax.Array:
        """Return sin(x^2 + 2*u*x + u^2) elementwise."""
        return jnp.sin(x**2 + 2*u*x + u**2)

    xs = th.rand(10, 1).requires_grad_()
    us = th.rand(10, 1).requires_grad_()

    # Compile
    afun(xs, us)

    zs = afun(xs, us)
    # Report the output's shape (the original printed xs.shape under this label).
    print("zs shape: ", zs.shape)

    # Autograd grad
    # create_graph=True keeps the graph alive for the backward() call below.
    xs_grad = th.autograd.grad(zs.sum(), xs, create_graph=True)[0]
    print("xs_grad shape: ", xs_grad.shape)

    # Autograd backwards
    zs.sum().backward()
    print("xs.grad shape: ", xs.grad.shape)
    print("us.grad shape: ", us.grad.shape)

    # Both differentiation routes must agree on d(sum zs)/d(xs).
    assert th.allclose(xs.grad, xs_grad)
This can be useful for those trying to use the MuJoCo MJX differentiable simulator to train RL policies and neural-network-based controllers. You can find a live snippet of the code at: PyTorch autograd differentiable JAX functions · GitHub.