I’ve noticed that certain PyTorch functions run slower after I make calls to `scipy.linalg.solve_triangular`.

I have managed to reproduce the issue with the following script:

```
import time

import numpy as np
import scipy.linalg
import torch


def torch_fn(input_size):
    # Time a single batched matrix multiply.
    with torch.no_grad():
        a = torch.randn(input_size, 4, 4, dtype=torch.float64)
        b = torch.randn(input_size, 4, 5, dtype=torch.float64)
        begin = time.perf_counter()
        c = torch.bmm(a, b)
        return time.perf_counter() - begin


def scipy_fn(input_size):
    # Solve a small lower-triangular system via scipy (LAPACK).
    a = np.asarray([[1, 0], [1, 1]], dtype=np.float64)
    b = np.random.randn(input_size, 2).astype(np.float64)
    begin = time.perf_counter()
    c = scipy.linalg.solve_triangular(a, b.T, lower=True).T
    return time.perf_counter() - begin


def avg_time(torch_input_size, scipy_input_size):
    # Average torch_fn's runtime over 100 iterations, optionally
    # interleaving a scipy call after each torch call.
    tot = 0
    for _ in range(100):
        tot += torch_fn(torch_input_size)
        if scipy_input_size > 0:
            scipy_fn(scipy_input_size)
    return tot / 100


if __name__ == '__main__':
    print(avg_time(5, 0))
    print(avg_time(5, 3))
    print(avg_time(5, 75))
    print(avg_time(30, 0))
    print(avg_time(30, 3))
    print(avg_time(30, 75))
```

The output I get is

```
0.0002978778630495071
8.901087567210197e-06
0.006054648053832352
0.0008917823247611523
8.631939999759198e-06
0.004081738111563027
```

So it seems that making a `scipy.linalg.solve_triangular` call with sufficiently large inputs causes `torch.bmm` to run much slower. The runtime difference above may not look significant, but in our codebase we call other torch functions and observe much larger gaps (e.g. ~30 ms when scipy is called vs. < 0.5 ms when it is not).
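One experiment I thought might narrow this down (this is my assumption, not something I have verified) is pinning the common BLAS/OpenMP thread pools to a single thread before importing anything, since a scipy call resizing a thread pool that torch also uses could plausibly slow subsequent torch ops:

```python
# Hypothesis check (assumption, not verified): pin the common BLAS/OpenMP
# thread-pool sizes via environment variables *before* importing
# numpy/scipy/torch, so a scipy call cannot resize a pool torch also uses.
import os

for var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS"):
    os.environ[var] = "1"

# ...then import torch/scipy and rerun the benchmark above.
```

These variables are only read at library initialization, which is why they have to be set before the imports.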

I suspect the issue has to do with the underlying linear algebra library that scipy calls into (in the case above, the `trtrs` routine in LAPACK). Beyond that, I can’t explain why this runtime discrepancy occurs.
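For reference, here is how the BLAS/LAPACK backends in play can be inspected (a sketch; the exact output format of `show_config` varies by version):

```python
# Print the BLAS/LAPACK backends NumPy and SciPy were built against; a
# mismatch between SciPy's backend and the one PyTorch links (e.g.
# OpenBLAS vs. MKL) might point to thread-pool or runtime interference.
import numpy as np
import scipy

np.show_config()     # NumPy's BLAS/LAPACK build info
scipy.show_config()  # SciPy's BLAS/LAPACK build info

# PyTorch's equivalent (prints its BLAS backend and OpenMP settings):
# import torch; print(torch.__config__.show())
```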

Does anyone have an explanation for this?