I am currently learning how to implement custom autograd.Function classes, and as practice I am trying to re-implement the backward pass of torch.linalg.cholesky. The backward pass below is intended to do the same thing as the cholesky_backward function at line 1939 in FunctionsManual.cpp; the only difference is that I have written it in Python and tried to avoid in-place operations. However, the code does not pass gradcheck, which raises an error stating that there is a “Jacobian mismatch”. Can anyone tell me what I am doing wrong?
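For reference, as far as I can tell from the code and comments in FunctionsManual.cpp, the backward pass computes (writing $\bar L$ for the incoming gradient with respect to $L$ and $\bar A$ for the returned gradient with respect to $A$)

$$
\bar A \;=\; L^{-H}\,\Phi\!\left(L^{H}\bar L\right)L^{-1},
\qquad
\Phi(X) \;=\; \tfrac{1}{2}\left(\operatorname{tril}(X) + \operatorname{tril}(X,-1)^{H}\right),
$$

i.e. $\Phi$ keeps the lower triangle of $X$, mirrors its strictly lower part into the upper triangle, and halves everything. The four gA lines in backward below are my attempt at a line-by-line translation of this.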
import torch


class CustomCholFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A: torch.Tensor) -> torch.Tensor:
        A_cpu = A.cpu().double()
        # Perform Cholesky decomposition
        L = torch.linalg.cholesky(A_cpu, upper=False)
        # Save the lower triangular matrix for the backward pass
        L_cpu = L.cpu()
        ctx.L_cpu = L_cpu
        return L_cpu.to(A.device, dtype=A.dtype)

    @staticmethod
    def backward(ctx, gL: torch.Tensor) -> torch.Tensor:
        gL_cpu = gL.cpu()
        L_cpu = ctx.L_cpu.clone()
        # This code is meant to be the same as the backward pass in the
        # autograd implementation of torch.linalg.cholesky
        upper = False  # Currently assume lower triangular
        L_ = L_cpu.adjoint() if upper else L_cpu
        gL_ = gL_cpu.adjoint() if upper else gL_cpu
        # Compute the gradient of A using out-of-place operations
        gA = torch.matmul(L_.adjoint(), gL_).tril()
        gA = 0.5 * (gA + gA.tril(-1).adjoint())
        gA = torch.linalg.solve_triangular(L_.adjoint(), gA, upper=True, left=True)
        gA = torch.linalg.solve_triangular(L_, gA, upper=False, left=False)
        return gA.to(gL.device)


# Make SPD matrix
A = torch.randn(5, 5, dtype=torch.double)
A = A @ A.t() + 0.1 * torch.eye(5, dtype=torch.double)
A.requires_grad = True

# Works as intended
CustomCholFunc.apply(A)

# Raises an error
torch.autograd.gradcheck(CustomCholFunc.apply, A)
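For what it's worth, a minimal sanity check along these lines would be to compare the gradient from the custom backward against the gradient autograd computes for the built-in torch.linalg.cholesky, using the same kind of SPD matrix and an arbitrary scalar loss. This is only a sketch: it assumes the CustomCholFunc class above is in scope, and make_spd is just a hypothetical helper for building the test matrix.

import torch

# Assumes CustomCholFunc from the snippet above is already defined.

def make_spd(n: int) -> torch.Tensor:
    # Hypothetical helper: build a symmetric positive-definite test matrix.
    M = torch.randn(n, n, dtype=torch.double)
    return M @ M.t() + 0.1 * torch.eye(n, dtype=torch.double)

A1 = make_spd(5).requires_grad_()
A2 = A1.detach().clone().requires_grad_()

# Push the same arbitrary scalar loss through both paths.
loss_custom = CustomCholFunc.apply(A1).sum()
loss_builtin = torch.linalg.cholesky(A2).sum()

(gA_custom,) = torch.autograd.grad(loss_custom, A1)
(gA_builtin,) = torch.autograd.grad(loss_builtin, A2)

# Compare the two analytic gradients element-wise.
print("max abs difference:", (gA_custom - gA_builtin).abs().max().item())
print("allclose:", torch.allclose(gA_custom, gA_builtin))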