How to cast/convert a JAX function into a PyTorch autograd-differentiable function

A way to convert a pure JAX function into an autograd-differentiable PyTorch function, so that it can be used in backward() and torch.autograd.grad() calls while supporting higher-order differentiation.

import functools

import torch as th
import jax
import jax.numpy as jnp

"""
Author: Josue N Rivera
"""

def t2j(tensor: th.Tensor) -> jax.Array:
    """Hand a PyTorch tensor to JAX through the DLPack protocol.

    The tensor is first detached from the autograd graph and made
    contiguous (``contiguous()`` copies only when the layout is not already
    contiguous); the resulting buffer is then imported as a JAX array.
    """
    detached = tensor.detach()
    dense = detached.contiguous()
    return jnp.from_dlpack(dense)

def j2t(array: jax.Array) -> th.Tensor:
    """Hand a JAX array to PyTorch through the DLPack protocol.

    The returned tensor is produced by ``torch.from_dlpack`` and carries no
    autograd history of its own.
    """
    tensor = th.from_dlpack(array)
    return tensor

def j2t_fun(fn):
    r"""
    Wrap a pure JAX function (N array inputs -> 1 array output) as a
    PyTorch-autograd-differentiable callable.

    Gradients are evaluated through ``jax.vjp`` and bridged with DLPack. The
    backward pass is itself built from :func:`j2t_fun` wrappers, so the
    callable supports differentiation to arbitrary order (e.g. ``dfdxx`` via
    repeated :func:`torch.autograd.grad`).

    Args:
        fn: JAX function taking ``N`` positional array arguments and
            returning a single array.

    Returns:
        A callable taking ``N`` positional ``torch.Tensor`` arguments and
        returning a ``torch.Tensor``. The wrapper keeps ``fn``'s name and
        docstring (via :func:`functools.wraps`).

    Note: This can be used directly as a decorator for JAX functions.
    """

    # Defined once per wrapped function rather than once per call: the class
    # only closes over ``fn``, so rebuilding it on every invocation is wasted
    # work.
    class JaxFn(th.autograd.Function):
        """Autograd node that runs ``fn`` in JAX and its VJP on demand."""

        @staticmethod
        def forward(ctx, *tensors):
            # Keep the original (graph-attached) inputs; backward re-feeds
            # them so higher-order derivatives can flow through them.
            ctx.save_for_backward(*tensors)
            return j2t(fn(*(t2j(t) for t in tensors)))

        @staticmethod
        def backward(ctx, grad):
            tensors = ctx.saved_tensors

            def make_vjp(i):
                # JAX function of (inputs..., cotangent) returning the
                # cotangent of ``fn`` with respect to input ``i``.
                def vjp_i(*inputs_and_cotangent):
                    *inputs, cotangent = inputs_and_cotangent
                    _, vjp = jax.vjp(fn, *inputs)
                    return vjp(cotangent)[i]

                return vjp_i

            # Wrapping each VJP with j2t_fun makes the gradient itself
            # differentiable. Inputs that do not require a gradient get
            # ``None`` so that no VJP is evaluated for them.
            return tuple(
                j2t_fun(make_vjp(i))(*tensors, grad) if needed else None
                for i, needed in enumerate(ctx.needs_input_grad)
            )

    @functools.wraps(fn)
    def wrapped(*args: th.Tensor) -> th.Tensor:
        return JaxFn.apply(*args)

    return wrapped

if __name__ == "__main__":
    # Demo: wrap a jitted JAX function and differentiate it from PyTorch.

    @j2t_fun
    @jax.jit
    def afun(x: jax.Array, u: jax.Array) -> jax.Array:
        return jnp.sin(x**2 + 2*u*x + u**2)

    xs = th.rand(10, 1).requires_grad_()
    us = th.rand(10, 1).requires_grad_()

    # Compile (first call triggers jax.jit tracing/compilation)
    afun(xs, us)

    zs = afun(xs, us)
    print("zs shape: ", zs.shape)

    # Autograd grad; create_graph=True keeps the result differentiable and
    # leaves the graph intact for the backward() call below.
    xs_grad = th.autograd.grad(zs.sum(), xs, create_graph=True)[0]
    print("xs_grad shape: ", xs_grad.shape)

    # Autograd backwards
    zs.sum().backward()
    print("xs.grad shape: ", xs.grad.shape)
    print("us.grad shape: ", us.grad.shape)

    # Both differentiation paths must agree on d(sum zs)/d(xs).
    assert th.allclose(xs.grad, xs_grad)

This can be useful for those trying to use the MuJoCo MJX differentiable simulator to train RL policies and neural-network-based controllers. You can find a live snippet of the code at: PyTorch autograd differentiable JAX functions · GitHub.