Going from 0.4.0 to 1.0.0 changes code runtime from 0:00:00.35 to 0:08:45.63

I recently installed pytorch 1.0.0 using conda; I think when I installed 0.4.0 it was from source. Using the exact same code, a single batch with a backwards pass went from taking less than a second to about 9 minutes (running on the CPU). Running torch.utils.bottleneck told me that almost all of the time is spent in the backwards pass, which was also evident when stepping through the code with a debugger.
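
(For reference, the same per-op breakdown can also be pulled directly from the autograd profiler. This is just a minimal sketch with a trivial stand-in computation in place of my actual model:)

import torch
from torch.autograd import profiler

# trivial stand-in computation; the real model isn't defined here
x = torch.randn(40, 200, 8, requires_grad=True)
loss = (x * 2).tanh().sum()

with profiler.profile() as prof:
    loss.backward()
# per-op breakdown, similar to what the bottleneck report shows
print(prof.key_averages().table(sort_by="cpu_time_total"))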

I imagine I should probably install pytorch 1.0.0 from source, as the issue likely stems from not using the correct linear algebra package or something like that. I’m using a Mac, and quite frankly, every time I’ve tried to build from source it has basically meant spending hours on StackOverflow trying to debug installation issues. I just want to install 1.0.0, without GPU support, without having to lose a day to googling obscure cmake errors.

I tried following the instructions on https://github.com/pytorch/pytorch and got:

CMake Error: Error: generator : Ninja
Does not match the generator used previously: Unix Makefiles
Either remove the CMakeCache.txt file and CMakeFiles directory or choose a different binary directory.
Traceback (most recent call last):
  File "setup.py", line 734, in <module>
    build_deps()
  File "setup.py", line 281, in build_deps
    build_dir='build')
  File "/Users/askates/Documents/GitRepos/pytorch/tools/build_pytorch_libs.py", line 232, in build_caffe2
    my_env)
  File "/Users/askates/Documents/GitRepos/pytorch/tools/build_pytorch_libs.py", line 213, in run_cmake
    check_call(cmake_args, cwd=build_dir, env=my_env)
  File "/Users/askates/anaconda3/envs/pytorch-1.0/lib/python3.6/subprocess.py", line 291, in check_call
    raise CalledProcessError(retcode, cmd)

Any suggestions as to what could be causing this problem specifically? Thanks for any help you can spare!


Could you try to uninstall all previous PyTorch builds and clean the build environment?

conda uninstall pytorch
pip uninstall torch
pip uninstall torch
python setup.py clean  # in your pytorch dir

EDIT: I would also recommend using a conda virtual environment.

Thanks for the response. I’ve been using a conda environment, though I did forget to uninstall the version I was already working with first. Uninstalling and cleaning made something of a difference, but it’s still failing:

[226/3254] Building CXX object third_party/protobuf/cmake/CMa.../__/src/google/protobuf/compiler/python/python_generator.cc.o
ninja: build stopped: subcommand failed.
Traceback (most recent call last):
  File "setup.py", line 734, in <module>
    build_deps()
  File "setup.py", line 281, in build_deps
    build_dir='build')
  File "/Users/askates/Documents/GitRepos/pytorch/tools/build_pytorch_libs.py", line 248, in build_caffe2
    check_call(ninja_cmd, cwd=build_dir, env=my_env)
  File "/Users/askates/anaconda3/envs/pytorch-1.0/lib/python3.6/subprocess.py", line 291, in check_call
    raise CalledProcessError(retcode, cmd)
subprocess.CalledProcessError: Command '['ninja', 'install']' returned non-zero exit status 1.

Can you spot any error messages before these lines, or could you post the whole build output?

I’ve saved the output here.

I noticed it’s not finding OpenMP. That definitely sounds like something that could cause a significant slowdown.
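
(Once I get a build to go through, I suppose a quick check like this would show whether the CPU backend actually ended up multi-threaded and whether MKL was picked up. Just a sketch:)

import torch

print(torch.__version__)
print(torch.get_num_threads())            # typically 1 if the build ended up without OpenMP
print(torch.backends.mkl.is_available())  # whether MKL was found for the linear algebra routines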

It looks like it’s failing at
confu-deps/QNNPACK/CMakeFiles/qnnpack.dir/src/operator-run.c.o
Could you try to update all submodules, clean, and build it again?
This should update all submodules:
git submodule update --init --recursive

I did as you said, and a few things happened.

I got the following when I tried to update all submodules:

(pytorch-1.0) Alexs-Macbook-2:pytorch askates$ git fetch --all
Fetching origin
(pytorch-1.0) Alexs-Macbook-2:pytorch askates$ git reset --hard origin/master
HEAD is now at 19addc7 Support nonzero onnx export
(pytorch-1.0) Alexs-Macbook-2:pytorch askates$ git submodule update --init --recursive
error: no such remote ref 9884f286a236a3b4e3218e4afa17781752e048bd
Fetched in submodule path 'third_party/onnx-tensorrt', but it did not contain 9884f286a236a3b4e3218e4afa17781752e048bd. Direct fetching of that commit failed.

Trying to build it afterwards gave me this result.

With this error:

ld: warning: object file (/usr/local/cuda/lib/libcudart_static.a(libcudart_static.a.o)) was built for newer OSX version (10.12) than being linked (10.9)
/Users/askates/Documents/GitRepos/pytorch/third_party/QNNPACK/src/operator-run.c:701:7: warning: implicit declaration of
      function 'pthreadpool_compute_3d_tiled' is invalid in C99 [-Wimplicit-function-declaration]
      pthreadpool_compute_3d_tiled(
      ^
/Users/askates/Documents/GitRepos/pytorch/third_party/QNNPACK/src/operator-run.c:703:10: error: use of undeclared identifier
      'pthreadpool_function_3d_tiled_t'
        (pthreadpool_function_3d_tiled_t) compute_sum_rows,
         ^
/Users/askates/Documents/GitRepos/pytorch/third_party/QNNPACK/src/operator-run.c:725:7: warning: implicit declaration of
      function 'pthreadpool_compute_4d_tiled' is invalid in C99 [-Wimplicit-function-declaration]
      pthreadpool_compute_4d_tiled(
      ^
/Users/askates/Documents/GitRepos/pytorch/third_party/QNNPACK/src/operator-run.c:727:12: error: use of undeclared identifier
      'pthreadpool_function_4d_tiled_t'
          (pthreadpool_function_4d_tiled_t) compute_q8gemm_xzp,
           ^
/Users/askates/Documents/GitRepos/pytorch/third_party/QNNPACK/src/operator-run.c:762:12: error: use of undeclared identifier
      'pthreadpool_function_4d_tiled_t'
          (pthreadpool_function_4d_tiled_t) compute_q8gemm,
           ^
/Users/askates/Documents/GitRepos/pytorch/third_party/QNNPACK/src/operator-run.c:802:12: error: use of undeclared identifier
      'pthreadpool_function_4d_tiled_t'
          (pthreadpool_function_4d_tiled_t) compute_q8conv,
           ^
2 warnings and 4 errors generated.
make[2]: *** [confu-deps/QNNPACK/CMakeFiles/qnnpack.dir/src/operator-run.c.o] Error 1
make[1]: *** [confu-deps/QNNPACK/CMakeFiles/qnnpack.dir/all] Error 2
make[1]: *** Waiting for unfinished jobs....
make: *** [all] Error 2
Traceback (most recent call last):
  File "setup.py", line 734, in <module>
    build_deps()
  File "setup.py", line 281, in build_deps
    build_dir='build')
  File "/Users/askates/Documents/GitRepos/pytorch/tools/build_pytorch_libs.py", line 251, in build_caffe2
    check_call(['make', '-j', str(max_jobs), 'install'], cwd=build_dir, env=my_env)
  File "/Users/askates/anaconda3/envs/pytorch-1.0/lib/python3.6/subprocess.py", line 291, in check_call
    raise CalledProcessError(retcode, cmd)
subprocess.CalledProcessError: Command '['make', '-j', '8', 'install']' returned non-zero exit status 2.

Edit:

I tried again, this time setting USE_CUDA=FALSE.

It got a bit further this time, but still failed.

[ 45%] Linking CXX static library ../../../lib/libprotoc.a
[ 45%] Built target libprotoc
make: *** [all] Error 2
Traceback (most recent call last):
  File "setup.py", line 734, in <module>
    build_deps()
  File "setup.py", line 281, in build_deps
    build_dir='build')
  File "/Users/askates/Documents/GitRepos/pytorch/tools/build_pytorch_libs.py", line 251, in build_caffe2
    check_call(['make', '-j', str(max_jobs), 'install'], cwd=build_dir, env=my_env)
  File "/Users/askates/anaconda3/envs/pytorch-1.0/lib/python3.6/subprocess.py", line 291, in check_call
    raise CalledProcessError(retcode, cmd)
subprocess.CalledProcessError: Command '['make', '-j', '8', 'install']' returned non-zero exit status 2.

Were there no error messages before this snippet?
I’m unfortunately not familiar with Macs and have never used one, so at the moment I can only guess what’s going wrong.
I assume you’ve added all the necessary env variables as described here.

Actually I ended up getting it working… deleting my local copy of pytorch and cloning it again let me install it successfully. The only trouble is that I’m still seeing the same code take orders of magnitude longer in pytorch 1.1.0a0+19addc7 (the version I’m now using) than it does in 0.4.0. Is there any way of finding out what the cause might be?

Good to hear it’s working now!
Could you create a code snippet to reproduce the performance issue so that we could have a look?

An example:

import torch
import time
import numpy as np
from torch.nn import functional as F
from torch.distributions.multivariate_normal import MultivariateNormal

_eps = 1e-20

def log_sum_exp(tensor, keepdim=True):
    """
    Numerically stable implementation for the `LogSumExp` operation. The
    summing is done along the last dimension.
    """
    max_val = tensor.max(dim=-1, keepdim=True)[0]
    return max_val + (tensor - max_val).exp().sum(dim=-1, keepdim=keepdim).log()


def loglik_mixture(x, z, c):
    """
    Calculate the log likelihood of a mixture of distributions
    """
    # Reshape the data
    x_reshape = x.contiguous().view(-1, x.shape[-1]).unsqueeze(0)
    x_repeat = x_reshape.repeat(c.shape[0], 1, 1)
    
    # Reshape the mixture weights
    z_reshape = z.view(-1, c.shape[0]) + _eps
    
    # Observation distribution parameters
    mu = x.data.new_zeros(x.shape[-1])  # zero mean
    sigma = c.unsqueeze(1)

    # Likelihood of each observation
    pdf = MultivariateNormal(mu, sigma).log_prob(x_repeat).t()
    
    # Log sum exp trick to sum the log probs, reshape so back in batch, seq_length
    return log_sum_exp(pdf + z_reshape.log()).reshape(*z.shape[:-1])

n_dim = 8
n_states = 6
n_batch = 40
seq_length = 200

c = torch.tensor(np.array([np.eye(n_dim)/i for i in range(1,n_states+1)]), requires_grad=True).float()
z = torch.randn([n_batch, seq_length, n_states], requires_grad=True).float()
z = F.softmax(z, dim=-1)
x = torch.randn([n_batch, seq_length, n_dim], requires_grad=False).float()

start = time.time()
loss = -loglik_mixture(x, z, c).sum()
print('Forwards pass: ', time.time() - start)

start = time.time()
loss.backward()
print('Backwards pass: ', time.time() - start)
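
(Aside: on recent versions the hand-rolled log_sum_exp helper above could presumably be swapped for the built-in, e.g.

torch.logsumexp(pdf + z_reshape.log(), dim=-1, keepdim=True)

though that’s not what’s slow here.)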

This function seems to be the main bottleneck in my code.

In 0.4.0 the forwards pass takes 0.0167 seconds and the backwards pass 0.0160 seconds.

In 1.1.0, the forwards pass takes 0.9668 seconds and the backwards pass 287.109 seconds.

Attached below are the bottleneck reports from 0.4.0 and 1.1.0, respectively:

--------------------------------------------------------------------------------
  Environment Summary
--------------------------------------------------------------------------------
PyTorch 0.4.0 not compiled w/ CUDA
Running with Python 3.6 and 

`pip list` truncated output:
Unable to fetch
--------------------------------------------------------------------------------
  cProfile output
--------------------------------------------------------------------------------
         290113 function calls (285732 primitive calls) in 2.371 seconds

   Ordered by: internal time
   List reduced from 2464 to 15 due to restriction <15>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    32/29    0.611    0.019    0.617    0.021 {built-in method _imp.create_dynamic}
      210    0.592    0.003    0.661    0.003 <frozen importlib._bootstrap_external>:830(get_data)
      981    0.147    0.000    0.147    0.000 {built-in method posix.stat}
        1    0.128    0.128    0.128    0.128 {method 'run_backward' of 'torch._C._EngineBase' objects}
       46    0.079    0.002    0.079    0.002 {built-in method io.open}
      210    0.069    0.000    0.069    0.000 {method 'read' of '_io.FileIO' objects}
        3    0.061    0.020    0.061    0.020 {built-in method numpy.core.multiarray.fromfile}
      210    0.048    0.000    0.048    0.000 {built-in method marshal.loads}
     1200    0.041    0.000    0.121    0.000 /Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/_functions/rnn.py:48(GRUCell)
     2400    0.039    0.000    0.039    0.000 {built-in method addmm}
        2    0.030    0.015    0.049    0.024 /Users/askates/anaconda3/envs/torch/lib/python3.6/site-packages/torch/distributions/multivariate_normal.py:67(_batch_mahalanobis)
  901/848    0.024    0.000    0.065    0.000 {built-in method builtins.__build_class__}
      357    0.020    0.000    0.020    0.000 {method 'astype' of 'numpy.ndarray' objects}
        1    0.016    0.016    0.400    0.400 /Users/askates/Documents/Projects/model/src/train.py:94(train)
   410/98    0.015    0.000    0.024    0.000 /Users/askates/anaconda3/envs/torch/lib/python3.6/sre_parse.py:470(_parse)


--------------------------------------------------------------------------------
  autograd profiler output (CPU mode)
--------------------------------------------------------------------------------
        top 15 events sorted by cpu_time_total

------------------  ---------------  ---------------  ---------------  ---------------  ---------------
Name                       CPU time        CUDA time            Calls        CPU total       CUDA total
------------------  ---------------  ---------------  ---------------  ---------------  ---------------
mul                     14247.185us          0.000us                1      14247.185us          0.000us
mul                     13269.355us          0.000us                1      13269.355us          0.000us
MulBackward1             8679.297us          0.000us                1       8679.297us          0.000us
mul                      8674.418us          0.000us                1       8674.418us          0.000us
sum                      5864.228us          0.000us                1       5864.228us          0.000us
_sum                     5862.115us          0.000us                1       5862.115us          0.000us
sum                      5621.976us          0.000us                1       5621.976us          0.000us
_sum                     5620.054us          0.000us                1       5620.054us          0.000us
ExpandBackward           3796.324us          0.000us                1       3796.324us          0.000us
sum                      3794.368us          0.000us                1       3794.368us          0.000us
_sum                     3792.730us          0.000us                1       3792.730us          0.000us
sub                      1926.867us          0.000us                1       1926.867us          0.000us
PowBackward0             1900.918us          0.000us                1       1900.918us          0.000us
sub                      1782.307us          0.000us                1       1782.307us          0.000us
repeat                   1180.078us          0.000us                1       1180.078us          0.000us
--------------------------------------------------------------------------------
  Environment Summary
--------------------------------------------------------------------------------
PyTorch 1.1.0a0+19addc7 compiled w/ CUDA 9.2
Running with Python 3.6 and CUDA 9.2.148

`pip3 list` truncated output:
numpy==1.14.4
torch==0.4.0
--------------------------------------------------------------------------------
  cProfile output
--------------------------------------------------------------------------------
         684208 function calls (678739 primitive calls) in 485.558 seconds

   Ordered by: internal time
   List reduced from 2460 to 15 due to restriction <15>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1  479.227  479.227  479.227  479.227 {method 'run_backward' of 'torch._C._EngineBase' objects}
      210    1.310    0.006    1.359    0.006 <frozen importlib._bootstrap_external>:830(get_data)
   128000    1.246    0.000    1.246    0.000 {built-in method trtrs}
    35/34    0.958    0.027    0.959    0.028 {built-in method _imp.create_dynamic}
   256000    0.701    0.000    0.701    0.000 /Users/askates/anaconda3/envs/pytorch-1.0/lib/python3.6/site-packages/torch/tensor.py:434(<lambda>)
        1    0.339    0.339  482.546  482.546 /Users/askates/Documents/Projects/model/src/train.py:94(train)
        2    0.313    0.157    2.260    1.130 /Users/askates/anaconda3/envs/pytorch-1.0/lib/python3.6/site-packages/torch/distributions/multivariate_normal.py:30(<listcomp>)
        4    0.295    0.074    0.295    0.074 {built-in method stack}
        4    0.280    0.070    0.280    0.070 {built-in method gru}
      968    0.100    0.000    0.100    0.000 {built-in method posix.stat}
       39    0.089    0.002    0.089    0.002 {built-in method io.open}
        2    0.062    0.031    2.616    1.308 /Users/askates/anaconda3/envs/pytorch-1.0/lib/python3.6/site-packages/torch/distributions/multivariate_normal.py:23(_batch_trtrs_lower)
      210    0.057    0.000    0.057    0.000 {built-in method marshal.loads}
      210    0.049    0.000    0.049    0.000 {method 'read' of '_io.FileIO' objects}
       14    0.035    0.002    0.035    0.002 {method 'reshape' of 'torch._C._TensorBase' objects}


--------------------------------------------------------------------------------
  autograd profiler output (CPU mode)
--------------------------------------------------------------------------------
        top 15 events sorted by cpu_time_total

------------------  ---------------  ---------------  ---------------  ---------------  ---------------
Name                       CPU time        CUDA time            Calls        CPU total       CUDA total
------------------  ---------------  ---------------  ---------------  ---------------  ---------------
stack                  144361.000us          0.000us                1     144361.000us          0.000us
stack                  137428.000us          0.000us                1     137428.000us          0.000us
StackBackward           95785.000us          0.000us                1      95785.000us          0.000us
unbind                  93179.000us          0.000us                1      93179.000us          0.000us
gru                     48150.000us          0.000us                1      48150.000us          0.000us
gru                     42979.000us          0.000us                1      42979.000us          0.000us
gru                     29062.000us          0.000us                1      29062.000us          0.000us
gru                     24434.000us          0.000us                1      24434.000us          0.000us
reshape                 22573.000us          0.000us                1      22573.000us          0.000us
clone                   22552.000us          0.000us                1      22552.000us          0.000us
reshape                 17592.000us          0.000us                1      17592.000us          0.000us
clone                   17574.000us          0.000us                1      17574.000us          0.000us
add                     17236.000us          0.000us                1      17236.000us          0.000us
SelectBackward          15354.000us          0.000us                1      15354.000us          0.000us
zeros                   15124.000us          0.000us                1      15124.000us          0.000us

--------------------------------------------------------------------------------
  autograd profiler output (CUDA mode)
--------------------------------------------------------------------------------
        top 15 events sorted by cpu_time_total

        Because the autograd profiler uses the CUDA event API,
        the CUDA time column reports approximately max(cuda_time, cpu_time).
        Please ignore this output if your code does not use CUDA.

-----------------  ---------------  ---------------  ---------------  ---------------  ---------------
Name                      CPU time        CUDA time            Calls        CPU total       CUDA total
-----------------  ---------------  ---------------  ---------------  ---------------  ---------------
stack                 221141.000us     221184.000us                1     221141.000us     221184.000us
stack                 149960.000us     150242.000us                1     149960.000us     150242.000us
gru                   143407.000us     143360.000us                1     143407.000us     143360.000us
StackBackward          95310.000us      95320.750us                1      95310.000us      95320.750us
unbind                 92935.000us      93348.500us                1      92935.000us      93348.500us
gru                    87878.000us      87873.500us                1      87878.000us      87873.500us
gru                    77095.000us      77952.000us                1      77095.000us      77952.000us
gru                    43305.000us      43295.469us                1      43305.000us      43295.469us
add                    34832.000us      34960.000us                1      34832.000us      34960.000us
reshape                31923.000us      31936.000us                1      31923.000us      31936.000us
clone                  31876.000us      31936.000us                1      31876.000us      31936.000us
add                    29661.000us      29696.000us                1      29661.000us      29696.000us
add                    22455.000us      22528.000us                1      22455.000us      22528.000us
add                    21128.000us      21248.000us                1      21128.000us      21248.000us
add                    20329.000us      23200.000us                1      20329.000us      23200.000us

Also the full traces here:

0.4.0 https://ufile.io/73ief

1.1.0 https://ufile.io/t0b1x

One thing that stands out is that the backwards pass in 1.1.0 seems to be made up almost entirely of blocks of TrtrsBackward, followed by SelectBackward and add, whereas 0.4.0 uses MulBackward, PowBackward, etc. Does this point towards a LAPACK issue?

I could reproduce this issue on 1.0.0.dev20190207 and 1.1.0a0+7343e47, while 0.4.1 runs fast.
Also, I tried to run torch.utils.bottleneck, but apparently more than 24GB of memory was allocated and the process ran out of memory. Was this also the case for your runs?

Thanks for the script. I’ll try to debug it a bit further.


I didn’t monitor the memory usage, but it didn’t seem to go out of memory for me when I ran it with bottleneck. Thanks for looking into it!

To narrow this down:
Replacing the pdf assignment with the following (nonsensical, but it produces the same output shape and keeps the backward graph connected - the shapes of mu, sigma, x_repeat, and pdf are (8,), (6, 1, 8, 8), (6, 8000, 8), and (8000, 6), respectively):

pdf = torch.einsum('i,jxii,jni->nj',mu, sigma, x_repeat)

gets rid of the slow part. I think there are some known (and possibly some unknown) performance issues with distributions.

Best regards

Thomas

P.S.: Minimal example:

mu = torch.randn(8)
sigma = torch.eye(8).expand(6, 1, 8, 8).contiguous().requires_grad_()
x_repeat = torch.randn(6,8000,8)
pdf = MultivariateNormal(mu, sigma).log_prob(x_repeat).t()
pdf.sum().backward()

Building with OpenMP on macOS is a little funky. It has improved a lot, but is still not perfect. However, 0.4 on macOS didn’t build with OpenMP IIRC, so that shouldn’t be your problem.

So the time is spent in log_prob -> _batch_mahalanobis -> _batch_trtrs_lower, which does the following - which seems like a bad idea:

flat_X = torch.stack([torch.trtrs(b, A, upper=False)[0] for b, A in zip(flat_b, flat_A)])

The minimal example based on your inputs would be

bb = torch.randn(48000, 8, 1)
bL = torch.randn(48000, 8, 8, requires_grad=True)
res = torch.distributions.multivariate_normal._batch_trtrs_lower(bb, bL)
res.sum().backward()

The 0.4 version of _batch_mahalanobis used torch.inverse instead of trtrs (we didn’t have batched inverse back then, so it did the stacking, too). While that would appear to involve more overhead in general, it seems much more efficient here, if only because of the administrative work involved in trtrs (and I never know what to make of the second return value of trtrs).
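
Just to give a feel for the difference, here is a rough, made-up micro-benchmark on those shapes (not taken from the profiles above; it uses the old trtrs call as in the snippet, and the values and conditioning are arbitrary):

import time
import torch

bb = torch.randn(48000, 8, 1)
# reasonably well-conditioned lower-triangular factors
mask = torch.ones(8, 8).tril()
bL = torch.randn(48000, 8, 8) * mask + 8 * torch.eye(8)

start = time.time()
# one trtrs call per matrix, as in _batch_trtrs_lower
looped = torch.stack([torch.trtrs(b, L, upper=False)[0] for b, L in zip(bb, bL)])
print('looped trtrs:   ', time.time() - start)

start = time.time()
# a single batched inverse + matmul (0.4 stacked per-matrix inverses, but the idea is similar)
batched = torch.matmul(torch.inverse(bL), bb)
print('batched inverse:', time.time() - start)

print(torch.allclose(looped, batched, atol=1e-3))  # same result up to float32 error
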
Here is the bug report.


Implementing the suggestions in the bug report, namely changing the shape of x to account for the batch shaping, has made a huge difference!

I now get:

0.4.0
Forwards pass:  0.05098533630371094
Backwards pass:  0.02379775047302246

1.1.0
Forwards pass:  0.15757107734680176
Backwards pass:  0.9299120903015137

Still not as fast as 0.4 but I’m not complaining!

The >10x slowdown is unfortunate, but at least it’s not completely unusable anymore…

Thank you for reporting the regression and providing a specific minimal example! This type of report is gold for hunting down bugs.

Best regards

Thomas

Thanks for your help, and to everyone else who helped! @ptrblck
