Hi,
I noticed that torch.linalg.solve_triangular accepts matrices with different types.
import torch

device = "cuda:0"
m = 1024
n = 10_000
L = torch.randn(m, m, dtype=torch.float64, device=device)
B = torch.randn(m, n, dtype=torch.float32, device=device)
# triangular solve in double precision: 2.74 ms
solve = torch.linalg.solve_triangular(L, B.type(torch.float64), upper=False)
# triangular solve without type conversion: 2.65 ms
solve = torch.linalg.solve_triangular(L, B, upper=False) # this works fine!?
Note that the second approach does no explicit type conversion, yet the code works just fine. I don't think this behavior is documented.
I also observed that the second approach is consistently slightly faster (a 3–5% reduction in running time). Indeed, GPU profiling shows that it launches fewer CUDA kernels.
My question is: how does the triangular solve work under the hood? Does it convert the matrix B to double precision on the fly? And is it numerically safe to solve triangular systems with mixed dtypes?
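For reference, here is a minimal, self-contained variant of the snippet above that checks the output dtype and the agreement between the two calls. It is a sketch under the assumption that the mixed-dtype call may not be supported on every backend, so that call is wrapped in a try/except; the matrix is also made well-conditioned (diagonally dominant) so the comparison is meaningful. It falls back to CPU when CUDA is unavailable.

```python
import torch

device = "cuda:0" if torch.cuda.is_available() else "cpu"
m, n = 64, 128

# Well-conditioned lower-triangular matrix in double precision
# (boosted diagonal so the solve is numerically stable).
L = torch.tril(torch.randn(m, m, dtype=torch.float64, device=device))
L = L + m * torch.eye(m, dtype=torch.float64, device=device)
B = torch.randn(m, n, dtype=torch.float32, device=device)

# Reference: everything in float64 via explicit conversion.
ref = torch.linalg.solve_triangular(L, B.to(torch.float64), upper=False)
print(ref.dtype)  # torch.float64
print(torch.allclose(L @ ref, B.to(torch.float64)))  # True

# Mixed-dtype call: may succeed on some backends or raise on others.
try:
    mixed = torch.linalg.solve_triangular(L, B, upper=False)
    print("mixed-dtype result:", mixed.dtype,
          torch.allclose(ref, mixed.to(torch.float64), atol=1e-5))
except RuntimeError as e:
    print("mixed-dtype solve raised:", e)
```

If the mixed-dtype result agrees with the reference to roughly float32 precision, that would suggest the conversion happens internally; a larger discrepancy would indicate the solve is actually performed in single precision.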