Custom autograd.Function for torch.linalg.cholesky fails gradcheck with “Jacobian mismatch”

I am currently learning how to implement custom torch.autograd.Function subclasses, and as practice I am trying to re-write the backward pass of torch.linalg.cholesky. The backward pass below is intended to compute the same thing as the cholesky_backward function at line 1939 in FunctionsManual.cpp; the only differences are that I have written it in Python and tried to avoid in-place operations.
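
For reference, as far as I can tell the formula that the C++ function (and therefore my port) implements is

    gA = L^{-H} · Φ(L^H · gL) · L^{-1},   where   Φ(X) = 0.5 * (X.tril() + X.tril(-1).adjoint())

with gL being the incoming gradient with respect to L; the two triangular solves in the code apply L^{-H} from the left and L^{-1} from the right. (This is my own reading of the comments in FunctionsManual.cpp, so it may be slightly off.)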

However, the code does not pass gradcheck, which raises an error stating that there is a “Jacobian mismatch”. Can anyone tell me what I am doing wrong? (The script below also includes a sanity check against built-in autograd, and after it a direct comparison of the two Jacobians.)

import torch
class CustomCholFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A: torch.Tensor) -> torch.Tensor:
        A_cpu = A.cpu().double()

        # Perform Cholesky decomposition
        L = torch.linalg.cholesky(A_cpu, upper=False)

        # Save the lower triangular matrix for backward pass
        ctx.save_for_backward(L)

        return L.to(A.device, dtype=A.dtype)

    @staticmethod
    def backward(ctx, gL: torch.Tensor) -> torch.Tensor:
        gL_cpu = gL.cpu()
        (L_cpu,) = ctx.saved_tensors

        # This code is meant to be the same as the backward pass in the
        # autograd implementation of torch.linalg.cholesky
        upper = False  # Currently assume lower triangular
        L_ = L_cpu.adjoint() if upper else L_cpu
        gL_ = gL_cpu.adjoint() if upper else gL_cpu

        # Compute gradient of A using out-of-place operations
        gA = torch.matmul(L_.adjoint(), gL_).tril()

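        # Φ step from the formula above: halve gA and reflect its strict
        # lower triangle into the upper one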
        gA = 0.5 * (gA + gA.tril(-1).adjoint())
        gA = torch.linalg.solve_triangular(L_.adjoint(), gA, upper=True, left=True)
        gA = torch.linalg.solve_triangular(L_, gA, upper=False, left=False)

        return gA.to(gL.device)

# Make an SPD test matrix
A = torch.randn(5, 5, dtype=torch.double)
A = A @ A.t() + 0.1 * torch.eye(5, dtype=torch.double)
A.requires_grad = True

# Works as intended
CustomCholFunc.apply(A)
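
# As a sanity check, compare my backward against the gradient that
# built-in autograd produces for the same cotangent (gL is arbitrary)
gL = torch.randn(5, 5, dtype=torch.double)

A_ref = A.detach().clone().requires_grad_()
torch.linalg.cholesky(A_ref).backward(gL)

A_custom = A.detach().clone().requires_grad_()
CustomCholFunc.apply(A_custom).backward(gL)

print(torch.allclose(A_ref.grad, A_custom.grad))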

# Raises the “Jacobian mismatch” error
torch.autograd.gradcheck(CustomCholFunc.apply, A)
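
To see where the two Jacobians disagree, I also compared them directly. This is only a rough sketch of what I understand gradcheck to be doing (entry-by-entry central differences; the step size eps is chosen ad hoc), not its exact algorithm:

from torch.autograd.functional import jacobian

A0 = A.detach().clone()

# Analytic Jacobian, routed through my custom backward
J_analytic = jacobian(CustomCholFunc.apply, A0)  # shape (5, 5, 5, 5)

# Naive finite-difference Jacobian, perturbing one entry of A0 at a time
eps = 1e-6
J_numeric = torch.zeros_like(J_analytic)
for i in range(5):
    for j in range(5):
        A_plus, A_minus = A0.clone(), A0.clone()
        A_plus[i, j] += eps
        A_minus[i, j] -= eps
        dL = torch.linalg.cholesky(A_plus) - torch.linalg.cholesky(A_minus)
        J_numeric[..., i, j] = dL / (2 * eps)

# Largest absolute disagreement between the two Jacobians
print((J_analytic - J_numeric).abs().max())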