Using autograd to calculate Taylor expansions

Hi,

so I have a function f: R^n -> R and I want to calculate the Taylor expansion around 0 for all orders up to total degree d (i.e. all multi-indices 000, 001, …, 111, …, (d-1)00, d00). My current approach is a slow Python loop using PyTorch autograd, and I need it to be more efficient. Is it possible to vectorize something here?

Naive code:


import itertools
import math

import torch


def multi_index_iterator(n, order):
    """
    Generate all multi-indices of dimension n with |alpha| <= order.
    """
    for total_order in range(order + 1):
        for alpha in itertools.product(range(total_order + 1), repeat=n):
            if sum(alpha) == total_order:
                yield alpha


def taylor_coefficients(f, x0, order):
    """
    Compute Taylor coefficients for f at x0 up to a given order.
    
    Parameters:
    - f: a function f: R^n -> R (expects a tensor input)
    - x0: a torch tensor of shape (n,) at which to expand (typically zero)
    - order: highest total derivative order
    Returns:
    - A dictionary mapping multi-indices (tuples) to their coefficient.
    """
    # Ensure x0 requires grad
    x0 = x0.clone().detach().requires_grad_(True)
    
    coeffs = {}
    # Loop over all multi-indices with |alpha| <= order.
    for alpha in multi_index_iterator(len(x0), order):
        # Flatten the multi-index into the list of coordinates to
        # differentiate with respect to, e.g. (2, 1) -> [0, 0, 1].
        coords = [i for i, a in enumerate(alpha) for _ in range(a)]
        derivative = f(x0)
        for i in coords:
            if not derivative.requires_grad:
                # The previous derivative is a constant (e.g. f is a
                # polynomial), so every further derivative is zero.
                derivative = torch.zeros(())
                break
            # create_graph=True keeps the graph so we can differentiate again.
            derivative = torch.autograd.grad(derivative, x0, create_graph=True)[0][i]
        derivative_val = derivative.detach().item()
        # Divide by the factorial of the multi-index to get the coefficient.
        factorial = math.prod(math.factorial(a) for a in alpha)
        coeffs[alpha] = derivative_val / factorial
    return coeffs
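
Side note on the enumeration itself: `itertools.product` generates and filters (total_order + 1)^n tuples per degree, discarding most of them. A sketch of a more direct enumeration via `itertools.combinations_with_replacement` (same set of multi-indices, different ordering within each degree; `multi_indices` is just an illustrative name):

```python
import itertools


def multi_indices(n, order):
    # Enumerate all multi-indices of dimension n with |alpha| <= order,
    # grouped by total degree, without generating and filtering the
    # full Cartesian product.
    for k in range(order + 1):
        # Each sorted choice of k coordinates (with repetition) maps to
        # exactly one multi-index of total degree k.
        for coords in itertools.combinations_with_replacement(range(n), k):
            yield tuple(coords.count(i) for i in range(n))


print(list(multi_indices(2, 2)))
# → [(0, 0), (1, 0), (0, 1), (2, 0), (1, 1), (0, 2)]
```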
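
On the vectorization question, one direction that avoids a separate `autograd.grad` chain per multi-index (a sketch, assuming PyTorch >= 2.0 so `torch.func.jacfwd` is available; `taylor_coefficients_vectorized` is just an illustrative name): build the full k-th derivative tensor once per order with nested forward-mode Jacobians, then read every degree-k coefficient out of that one symmetric tensor.

```python
import itertools
import math

import torch
from torch.func import jacfwd


def taylor_coefficients_vectorized(f, x0, order):
    # One nested jacfwd call per derivative order instead of one
    # autograd.grad chain per multi-index. After k applications,
    # deriv[i1, ..., ik] == d^k f / dx_{i1} ... dx_{ik}: a symmetric
    # tensor of shape (n,) * k.
    n = x0.shape[0]
    coeffs = {}
    g, deriv = f, f(x0)
    for k in range(order + 1):
        if k > 0:
            g = jacfwd(g)   # differentiate w.r.t. all coordinates at once
            deriv = g(x0)
        # Every multi-index of total degree k corresponds to one sorted
        # index tuple into the symmetric derivative tensor.
        for idx in itertools.combinations_with_replacement(range(n), k):
            alpha = tuple(idx.count(i) for i in range(n))
            factorial = math.prod(math.factorial(a) for a in alpha)
            coeffs[alpha] = deriv[idx].item() / factorial
    return coeffs
```

The trade-off: the k-th derivative tensor has n^k entries, so this exchanges the Python loop for tensor work and is only attractive for small n and moderate d, but each order then comes out of a single forward-mode pass.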