BLAS performance on macOS vs Linux vs Lua Torch

I installed pytorch on several machines, both from source and from conda, and I am getting different execution times for matrix multiplication. All installs use Anaconda 4.3.0 with Python 3.6. However, I can’t figure out whether pytorch is using MKL, OpenBLAS, or some other backend. Right now the macOS install is the fastest, even though that machine has the slowest CPU.

The reason I ran these tests is that I noticed a severe slowdown (roughly 10x) of a multiprocessing RL algorithm I am working on when it runs on the Linux machines.

On the Linux machines torch seems to use only a single thread for the multiplication, unlike on macOS, even though torch.get_num_threads() returns the correct number of threads on each system.
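
For reference, a minimal sketch of how the CPU thread count can be inspected and set from Python (torch.get_num_threads() / torch.set_num_threads() are the documented knobs; this is just an illustration):

import torch

print(torch.get_num_threads())  # threads torch reports for CPU operations
torch.set_num_threads(8)        # explicitly request 8 threads for CPU BLAS calls
print(torch.get_num_threads())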

### Results:

macOS: Sierra, CPU: Intel i7-4870HQ (8) @ 2.50GHz, 16GB RAM, GeForce GT 750M. Installed from sources.

Allocation:   5.921
Torch Blas:   7.277
Numpy Blas:   7.841
Torch cuBlas: 0.205

Ubuntu 16.10, CPU: Intel i7-4720HQ (8) @ 3.60GHz, 16GB RAM, GeForce GTX 960M. Installed from sources.

Allocation:   4.030
Torch Blas:   21.112
Numpy Blas:   21.82
Torch cuBlas: 0.121

CentOS 7.2, CPU: Intel Xeon E5-2640v4 (40) @ 2.40GHz, 16GB RAM, Titan X. Installed both from source and with conda. Also ran the test with python 3.5 and pytorch built from sources.

Allocation:   4.557
Torch Blas:   19.646
Numpy Blas:   20.155
Torch cuBlas: 0.155

Finally, this is the output of np.__config__.show() on all the machines:

openblas_lapack_info:
  NOT AVAILABLE
lapack_opt_info:
    define_macros = [('SCIPY_MKL_H', None), ('HAVE_CBLAS', None)]
    include_dirs = ['/opt/anaconda3/include']
    library_dirs = ['/opt/anaconda3/lib']
    libraries = ['mkl_intel_lp64', 'mkl_intel_thread', 'mkl_core', 'iomp5', 'pthread']
blas_mkl_info:
    ...
blas_opt_info:
    ...
lapack_mkl_info:
    ...
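
There is no torch-side equivalent of np.__config__.show() in these installs; as an aside, newer PyTorch releases expose torch.__config__.show(), which prints the build settings, including which BLAS/LAPACK the binary was linked against (availability depends on the installed version):

import torch

# Only available in newer PyTorch releases; prints the build configuration,
# including which BLAS/LAPACK (e.g. MKL) the binary was compiled against.
print(torch.__config__.show())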

The code I am using:

import time
import torch
import numpy
torch.set_default_tensor_type("torch.FloatTensor")

w = 5000
h = 40000
is_cuda = torch.cuda.is_available()
start = time.time()

# Allocate the operands on CPU (and on GPU if available)
a = torch.rand(w, h)
b = torch.rand(h, w)
a_np = a.numpy()
b_np = b.numpy()
if is_cuda:
    a_cu = a.cuda()
    b_cu = b.cuda()

allocation = time.time()
print("Allocation ", allocation - start)

# CPU matrix multiplication through torch's BLAS
c = a.mm(b)
th_blas = time.time()
print("Torch Blas ", th_blas - allocation)

# The same multiplication through numpy's BLAS
c = a_np.dot(b_np)
np_blas = time.time()
print("Numpy Blas ", np_blas - th_blas)

if is_cuda:
    # GPU multiplication through cuBLAS
    c = a_cu.mm(b_cu)
    cu_blas = time.time()
    print("Torch cuBlas ", cu_blas - np_blas)

print("Final", time.time() - start)

Edit: for comparison, here are the results of the equivalent script in Lua Torch on the last machine from above:

Allocation:  	4.426
Torch Blas: 	2.777
Torch cuBlas: 	0.097

At this point I am more inclined to believe my Linux pytorch installs are using a BLAS fallback. Hoping this isn’t Python’s overhead… :frowning:

Can you try exporting MKL_NUM_THREADS to match the number of cores you have?
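
A minimal sketch of one way to do that from Python (assuming MKL reads the variable when it is first loaded, so it has to be set before importing torch or numpy):

import os

# Assumption: MKL picks these up at load time, so set them before the imports.
os.environ["MKL_NUM_THREADS"] = "8"  # match the number of physical cores
os.environ["OMP_NUM_THREADS"] = "8"

import torch

print(torch.get_num_threads())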

Did that, it didn’t help. Is the variable relevant at runtime only, or should I also try to recompile torch with it exported?

MKL seems properly configured:

>>> mkl.get_max_threads()
20

I also installed anaconda2 with Python 2.7 and pytorch (from conda) today on the Ubuntu laptop described above, and I got the same figures. I can reproduce this on CentOS and Ubuntu, on four different machines, with Python 3.6, 3.5 and 2.7, and with pytorch installed both from source and from conda.

Can someone else run the script above and report the numbers?

I’m seeing similar behavior. I ran your script with both PyTorch built from source and the conda installation:

built from source:

Allocation  6.476427316665649
Torch Blas  4.414772272109985
Numpy Blas  2.665677547454834
Torch cuBlas  0.14812421798706055
Final 13.705262184143066

conda:

Allocation  5.521166086196899
Torch Blas  39.35049605369568
Numpy Blas  39.40145969390869
Final 84.42150139808655

It looks like something is wrong with Conda.

A minor note: your script only measures the cuBlas launch time, not the execution time. You need a torch.cuda.synchronize() call before stopping the timer to measure the actual execution time.
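
A rough sketch of what that would look like (assumed shapes, just to illustrate where the synchronize calls go):

import time
import torch

a = torch.rand(5000, 40000).cuda()
b = torch.rand(40000, 5000).cuda()
torch.cuda.synchronize()            # make sure the uploads have finished

start = time.time()
c = a.mm(b)
torch.cuda.synchronize()            # wait for the kernel to actually complete
print("Torch cuBlas", time.time() - start)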

Thanks for taking the time to check this, and for the tip on benchmarking CUDA.

I’ll try to build it again, although I have done this several times with the same result.

Can you paste the full log from python setup.py install into a gist? Maybe your MKL isn’t picked up.

I won’t be able to post the log until later today. However, before starting the thread I looked specifically for MKL-related messages: the build was picking up the MKL headers/objects from anaconda and also passing the BLAS-related tests for operations such as gemm. I’ll post the log as soon as I can.

@apaszke OK, fresh install, same behaviour; here is the full log.

Ok, MKL is found. We’ll have to look into it. Thanks for the report.

Let me know if there’s anything else I can look into, especially if you can’t reproduce this situation.

What fixed things for me was adding “iomp5” to FindMKL.cmake:

diff --git a/torch/lib/TH/cmake/FindMKL.cmake b/torch/lib/TH/cmake/FindMKL.cmake
index e68ae6a..7c9325a 100644
--- a/torch/lib/TH/cmake/FindMKL.cmake
+++ b/torch/lib/TH/cmake/FindMKL.cmake
@@ -50,7 +50,7 @@ ENDIF ("${SIZE_OF_VOIDP}" EQUAL 8)
 IF(CMAKE_COMPILER_IS_GNUCC)
   SET(mklthreads "mkl_gnu_thread" "mkl_intel_thread")
   SET(mklifaces  "gf" "intel")
-  SET(mklrtls)
+  SET(mklrtls "iomp5")
 ELSE(CMAKE_COMPILER_IS_GNUCC)
   SET(mklthreads "mkl_intel_thread")
   SET(mklifaces  "intel")

otherwise the mkl_sequential library is used (and your log shows this), but I don’t know nearly enough about the interactions between compilers and threading libraries to tell whether this is a robust solution. @apaszke, @colesbury, I can submit a PR if you think that’s OK.
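
If it helps with verification, here is a rough, Linux-only sketch for checking which MKL/OpenMP libraries actually get loaded at runtime (the library names are an assumption based on the MKL shipped with Anaconda):

import torch  # import first so its shared libraries get mapped into the process

# Linux-only: list the MKL/OpenMP libraries mapped into this process.
with open("/proc/self/maps") as f:
    libs = sorted({line.split()[-1] for line in f
                   if "mkl" in line or "iomp" in line})
for lib in libs:
    print(lib)

# libmkl_sequential in this list means the single-threaded MKL path was linked;
# libmkl_intel_thread together with libiomp5 indicates the threaded variant.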

@ngimel yes, please send a PR. It would be even simpler for us if you could send it to torch/torch7 directly. Thanks!

@ngimel @apaszke I can confirm this indeed fixes the issue. Thank you for all the support, much appreciated!