Torch function runtime dependent on scipy call?

I’ve noticed that certain PyTorch functions run slower after I make calls to scipy.linalg.solve_triangular in the same process.

I have managed to reproduce this issue with the following script:

import time

import numpy as np
import torch
import scipy.linalg

def torch_fn(input_size):
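    # Time a single batched matmul (torch.bmm); only the bmm call itself is timed.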
    with torch.no_grad():
        a = torch.randn(input_size, 4, 4, dtype=torch.float64)
        b = torch.randn(input_size, 4, 5, dtype=torch.float64)
        begin = time.perf_counter()
        c = torch.bmm(a, b)
        return time.perf_counter() - begin


def scipy_fn(input_size):
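    # Time a single scipy.linalg.solve_triangular solve against a small lower-triangular matrix.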
    a = np.asarray([[1, 0], [1, 1]], dtype=np.float64)
    b = np.random.randn(input_size, 2).astype(np.float64)
    begin = time.perf_counter()
    c = scipy.linalg.solve_triangular(a, b.T, lower=True).T
    return time.perf_counter() - begin


def avg_time(torch_input_size, scipy_input_size):
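    # Average the torch.bmm time over 100 iterations, optionally calling scipy_fn in between.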
    tot = 0
    for _ in range(100):
        tot += torch_fn(torch_input_size)
        if scipy_input_size > 0:
            scipy_fn(scipy_input_size)
    return tot / 100


if __name__ == '__main__':
    print(avg_time(5, 0))
    print(avg_time(5, 3))
    print(avg_time(5, 75))
    print(avg_time(30, 0))
    print(avg_time(30, 3))
    print(avg_time(30, 75))

The output I get is

0.0002978778630495071
8.901087567210197e-06
0.006054648053832352
0.0008917823247611523
8.631939999759198e-06
0.004081738111563027

So it seems that a scipy.linalg.solve_triangular call with sufficiently large inputs causes subsequent torch.bmm calls to run much slower. The runtime difference above may not look significant, but in our codebase we call other torch functions and see much larger gaps (e.g. ~30 ms with the scipy call vs. < 0.5 ms without it).

I suspect the issue may have to do with the underlying linear algebra library that scipy calls into (in this case, the trtrs routine from LAPACK). Beyond that, I can’t explain why this runtime discrepancy occurs.
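For what it’s worth, one way I can think of to narrow this down would be to print which BLAS/LAPACK builds each library is linked against, and then re-run the timing with all thread pools capped at one thread: if the slowdown disappears, some interaction between the OpenMP/BLAS thread pools would be a likely suspect. This is only a rough sketch of that idea (threadpoolctl is a third-party package, and the thread-pool interaction is just a guess on my part):

import time

import numpy as np
import scipy
import torch

# Print the BLAS/LAPACK build information each library was compiled against.
np.show_config()
scipy.show_config()
print(torch.__config__.show())
print("torch intra-op threads:", torch.get_num_threads())

# threadpoolctl can enumerate the OpenMP/BLAS thread pools loaded into the
# process and temporarily cap their sizes.
from threadpoolctl import threadpool_info, threadpool_limits

for pool in threadpool_info():
    print(pool)

# Re-time a small bmm with every thread pool capped at one thread.
a = torch.randn(30, 4, 4, dtype=torch.float64)
b = torch.randn(30, 4, 5, dtype=torch.float64)
with threadpool_limits(limits=1):
    begin = time.perf_counter()
    torch.bmm(a, b)
    print("bmm with pools capped at 1 thread:", time.perf_counter() - begin)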

Does anyone have an explanation for this?

Another interesting observation is that the order of the calls also matters: the results are different (at least on my setup) when the main block is executed in this order:

if __name__ == '__main__':
    print(avg_time(30, 75))
    print(avg_time(30, 3))
    print(avg_time(30, 0))
    print(avg_time(5, 75))
    print(avg_time(5, 3))
    print(avg_time(5, 0))

Would you mind creating an issue on GitHub so that we can track and investigate it?