Hello,
I wrote a subclass for solve_triangular systems, then I tried to use the gradcheck, but it reports False…
Could you help to review this code? Thanks.
class SolveTrianguler(Function):
# sloves A * x = b
def __init__(self, trans=0, lower=True):
super(SolveTrianguler, self).__init__()
# trans=1, transpose the matrix A.T * x = b
self.trans = trans
# lower=False, use data contained in the upper triangular, the default is lower
self.lower = lower
# self.needs_input_grad = (True, False)
def forward(self, matrix, rhs):
x = torch.from_numpy(
solve_triangular(matrix.numpy(), rhs.numpy(),
trans=self.trans, lower=self.lower))
self.save_for_backward(matrix, x)
return x
def backward(self, grad_output):
# grad_matrix = grad_rhs = None
matrix, x = self.saved_tensors
# formula from Giles 2008, 2.3.1
return -matrix.inverse().t().mm(grad_output).mm(torch.t(x)), \
matrix.inverse().t().mm(grad_output)