Performance of Highly Nested Tensor Operations

I’ve implemented a numerical integration library for Ordinary Differential Equations called DESolver.

For this library, I implemented a PyTorch backend in order to permit differentiating the results of the numerical integration with respect to any tensor.

The first thing I’ve noticed is that the PyTorch backend is, on average, 5x slower than the NumPy backend. This could be because my NumPy distribution is MKL-optimized, so it’s not my main concern.

My main concern is that differentiating the results is very slow. In the code below, the numerical integration itself takes ~5s, but computing the Jacobian of the results with respect to the initial conditions takes ~15s. This is surprisingly slow, and I am wondering whether I am doing something odd or whether autograd is simply inefficient for very deep graphs.

Further details:

(In this case, the numerical integration takes ~2000 steps, so the final state is at least 2000 operations removed from the initial state in the computational graph.)

import os
os.environ['DES_BACKEND'] = 'torch'

import desolver as de
import desolver.backend as D

D.set_float_fmt('float64')

import torch

torch.set_printoptions(precision=17)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# initial state (the tensor we differentiate with respect to) and the constant system matrix
yi = torch.tensor([1.0, 1.0], requires_grad=True, device=device)
m  = torch.tensor([[0.0, 1.0],[-20.0, 0.0]], device=device)

# RHS of the linear ODE dy/dt = m @ y (a harmonic oscillator with omega^2 = 20)
def df(t, x):
    return torch.mv(m, x)

a = de.OdeSystem(df, yi, t=(0, 2*D.pi), dt=0.1, rtol=1e-12, atol=1e-12)
a.set_method("RK45CK")
a.integrate(eta=True)

print(D.jacobian(a[-1][1], a[0][1]))
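
To illustrate the graph-depth point independently of DESolver, here is a minimal sketch in which a fixed-step Euler loop stands in for RK45CK (the names and the step count are purely illustrative). Every torch.autograd.grad call has to walk back through the entire ~2000-step graph, so building the Jacobian re-traverses the graph once per output element:

```python
import math
import time

import torch

# Toy stand-in for the integration above: a fixed-step Euler loop, so the final
# state is ~2000 elementary operations removed from the initial state.
x0 = torch.tensor([1.0, 1.0], dtype=torch.float64, requires_grad=True)
M  = torch.tensor([[0.0, 1.0], [-20.0, 0.0]], dtype=torch.float64)
n_steps = 2000
dt = 2 * math.pi / n_steps

t0 = time.perf_counter()
x = x0
for _ in range(n_steps):
    x = x + dt * torch.mv(M, x)
t_forward = time.perf_counter() - t0

# Building the Jacobian wrt x0 takes one backward pass per output element, and
# every backward pass traverses the full ~2000-step graph.
t0 = time.perf_counter()
rows = [torch.autograd.grad(x[i], x0, retain_graph=True)[0] for i in range(x.numel())]
jac = torch.stack(rows)
t_backward = time.perf_counter() - t0

print(f"forward: {t_forward:.3f}s, jacobian: {t_backward:.3f}s")
print(jac)
```

With only two state variables this is just two extra traversals of the graph; for larger states the jacobian function below loops over out_tensor.numel() output elements, so the cost grows accordingly.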

The jacobian function I am using:

def jacobian(out_tensor, in_tensor, batch_mode=False, nu=1, create_graph=True):
    """Computes the derivative of an output tensor wrt an input tensor.

    Computes the full nu-th order derivative of the output tensor wrt an input tensor.
    For nu = 1 this is the Jacobian, for nu = 2 the Hessian, etc.
    The computation scales with the number of output values, i.e. out_tensor.numel(), so it
    becomes quite slow for very large tensors.

    The batched computation assumes that the first dimension is the batch dimension and computes the
    derivative for all the batch elements. The batches are computed in parallel, so for reasonable
    batch sizes the computation should scale as out_tensor.numel() / out_tensor.shape[0].

    Parameters
    ----------
    out_tensor : torch.tensor
        The function whose derivative is to be computed
    in_tensor  : torch.tensor
        The input wrt which the derivative is to be computed
    batch_mode : bool
        Determines if the first dimension is to be treated as a batch dimension or not
    nu : int
        Order of the derivative to be computed
    create_graph : bool
        To keep the computational graph after the jacobian is computed. This is useful if you intend to
        compute further derivatives on the derivative, e.g. for gradient descent.

    Returns
    -------
    torch.tensor
        The derivative tensor of out_tensor wrt in_tensor

    Raises
    ------
    ValueError
        If nu < 0 as that is not a valid derivative order.

    See Also
    --------
    torch.autograd.grad : The base function through which gradients are computed

    Examples
    --------
    ```python
    >>> b   = torch.tensor( [0.0, 1.0], dtype=torch.float64, requires_grad=True)
    >>> mat = torch.tensor([[0.0, 1.0], [-5.0, 0.0]], dtype=torch.float64, requires_grad=True)
    >>> k   = mat@b
    >>> jacobian(k, b, nu=1)
    tensor([[ 0.,  1.],
            [-5.,  0.]], dtype=torch.float64)
    >>> jacobian(k, mat, nu=1)
    tensor([[[0., 1.],
             [0., 0.]],

            [[0., 0.],
             [0., 1.]]], dtype=torch.float64, grad_fn=<AsStridedBackward>)
    ```
    """
    if nu < 0:
        raise ValueError("nu cannot be less than zero! That's not a derivative...")
    if nu == 0:
        return out_tensor
    if not out_tensor.requires_grad:
        if batch_mode:
            temp = torch.zeros(out_tensor.shape + in_tensor.shape[1:], dtype=in_tensor.dtype, device=in_tensor.device, requires_grad=False)
        else:
            temp = torch.zeros(out_tensor.shape + in_tensor.shape, dtype=in_tensor.dtype, device=in_tensor.device, requires_grad=False)
    else:
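        # One torch.autograd.grad call is made per output element (or per output
        # feature in batch mode); each call traverses the full graph of out_tensor.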
        if batch_mode:
            outputs_view = out_tensor.view(out_tensor.shape[0], -1)
            batch_one    = torch.ones_like(outputs_view[:, 0])
            temp = [
                torch.autograd.grad(
                    outputs_view[:, j], 
                    in_tensor,
                    grad_outputs=batch_one,
                    allow_unused=True,
                    retain_graph=True,
                    create_graph=create_graph if nu==1 else True
                )[0] for j in range(outputs_view.shape[1])]
            final_shape = out_tensor.shape + in_tensor.shape[1:]
        else:
            outputs_view = out_tensor.view(-1)
            temp = [torch.autograd.grad(
                outputs_view[i], 
                in_tensor,
                allow_unused=True,
                retain_graph=True,
                create_graph=create_graph if nu==1 else True,
            )[0] for i in range(outputs_view.shape[0])]
            final_shape = out_tensor.shape + in_tensor.shape
        # grads for unused inputs come back as None; replace them with zeros
        temp = torch.stack([
            i if i is not None else torch.zeros_like(in_tensor) for i in temp
        ])
        if batch_mode:
            # stacking was over output features, so move the batch dimension back to the front
            temp = temp.transpose(0, 1)
        temp = temp.reshape(final_shape)
    if nu > 1:
        temp = jacobian(temp, in_tensor, create_graph=create_graph, nu=nu-1, batch_mode=batch_mode)
    return temp
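
For completeness, this is roughly how I use the batched mode of the function above; the tensors xb, A and yb are just toy values for illustration, not DESolver state. Because the batch elements are independent, the derivative needs only one torch.autograd.grad call per output feature (2 here) rather than one per output element (8 here):

```python
# Toy batched usage of jacobian() above (illustrative names, not DESolver state).
xb = torch.randn(4, 2, dtype=torch.float64, requires_grad=True)   # (batch, state)
A  = torch.tensor([[0.0, 1.0], [-20.0, 0.0]], dtype=torch.float64)
yb = xb @ A.T                                                      # (batch, state)

Jb = jacobian(yb, xb, batch_mode=True)
print(Jb.shape)   # torch.Size([4, 2, 2]); each Jb[b] is d(yb[b]) / d(xb[b]), i.e. A
```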