Summary
When running nn.MultiheadAttention on the GPU, most of the time is spent in torch.equal (called in https://github.com/pytorch/pytorch/blob/7036e91abd6ba5e9d20afa957913a0cd7c08be81/torch/nn/functional.py#L4100).
To reproduce
import copy
import cProfile
import timeit
import torch
from torch import nn
a = torch.randn(1024, 2, 256, device='cuda')
b = torch.randn(1024, 2, 256, device='cuda')
_ = torch.matmul(a.transpose(1, 2), b)
self_attention = nn.MultiheadAttention(256, 8).to('cuda')
layers = nn.ModuleList([copy.deepcopy(self_attention) for _ in range(20)])
c = torch.randn(1024, 2, 256, device='cuda')
d = torch.randn(1024, 2, 256, device='cuda')
t1 = timeit.timeit('torch.equal(c, d)', number=1, globals=globals())*1e3
t2 = timeit.timeit('torch.equal(c, d)', number=1, globals=globals())*1e3
print(f"First time (before self-attention): {t1: .4f} ms")
print(f"Second time (before self-attention): {t2: .4f} ms\n")
def run_test(a, b, layers):
    for layer in layers:
        a = layer(a+b, a+b, a)[0]
cProfile.runctx('run_test(a, b, layers)', globals=globals(), locals={})
cProfile.runctx('run_test(a, b, layers)', globals=globals(), locals={})
cProfile.runctx('run_test(a, b, layers)', globals=globals(), locals={})
t1 = timeit.timeit('torch.equal(c, d)', number=1, globals=globals())*1e3
t2 = timeit.timeit('torch.equal(c, d)', number=1, globals=globals())*1e3
print(f"First time (after self-attention): {t1: .4f} ms")
print(f"Second time (after self-attention): {t2: .4f} ms\n")
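The cProfile.runctx calls above print their rows in standard-name order, which makes the dominant entry harder to spot. As a minimal sketch using only the standard library, the same kind of profile can be collected and sorted by tottime so that the most expensive call comes first. The names profile_sorted and stand_in_workload are hypothetical; stand_in_workload is only a placeholder for run_test(a, b, layers).

```python
import cProfile
import io
import pstats


def profile_sorted(fn, *args, sort_key="tottime", limit=10):
    """Profile fn(*args) and return the report text, most expensive rows first."""
    profiler = cProfile.Profile()
    profiler.runcall(fn, *args)
    buffer = io.StringIO()
    stats = pstats.Stats(profiler, stream=buffer)
    # "tottime" sorts by time spent inside each function itself,
    # excluding time spent in the functions it calls.
    stats.sort_stats(sort_key).print_stats(limit)
    return buffer.getvalue()


def stand_in_workload(n):
    # Hypothetical placeholder for run_test(a, b, layers).
    return sum(i * i for i in range(n))


report = profile_sorted(stand_in_workload, 100_000)
print(report)
```

In the reproduction script itself, passing sort='tottime' to cProfile.runctx should give the same ordering without the extra helper.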
Results
First time (before self-attention): 0.1915 ms
Second time (before self-attention): 0.1260 ms
2187 function calls in 0.121 seconds
Ordered by: standard name
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.000 0.000 0.121 0.121 <string>:1(<module>)
20 0.000 0.000 0.000 0.000 _VF.py:25(__getattr__)
140 0.000 0.000 0.000 0.000 _jit_internal.py:713(is_scripting)
20 0.000 0.000 0.119 0.006 activation.py:923(forward)
1 0.000 0.000 0.000 0.000 container.py:184(__iter__)
20 0.000 0.000 0.005 0.000 functional.py:1473(softmax)
80 0.003 0.000 0.008 0.000 functional.py:1663(linear)
80 0.000 0.000 0.000 0.000 functional.py:1678(<listcomp>)
20 0.006 0.000 0.118 0.006 functional.py:3860(multi_head_attention_forward)
20 0.000 0.000 0.000 0.000 functional.py:3944(<listcomp>)
20 0.000 0.000 0.000 0.000 functional.py:950(dropout)
20 0.000 0.000 0.119 0.006 module.py:715(_call_impl)
120 0.000 0.000 0.000 0.000 module.py:765(__getattr__)
40 0.000 0.000 0.000 0.000 overrides.py:1058(has_torch_function)
260 0.000 0.000 0.000 0.000 overrides.py:1071(<genexpr>)
1 0.001 0.001 0.120 0.120 test_equal.py:24(run_test)
40 0.006 0.000 0.006 0.000 {built-in method bmm}
140 0.000 0.000 0.000 0.000 {built-in method builtins.any}
1 0.000 0.000 0.121 0.121 {built-in method builtins.exec}
160 0.000 0.000 0.000 0.000 {built-in method builtins.getattr}
1 0.000 0.000 0.000 0.000 {built-in method builtins.iter}
40 0.000 0.000 0.000 0.000 {built-in method builtins.len}
20 0.000 0.000 0.000 0.000 {built-in method dropout}
60 0.087 0.001 0.087 0.001 {built-in method equal}
20 0.000 0.000 0.000 0.000 {built-in method torch._C._get_tracing_state}
40 0.000 0.000 0.000 0.000 {built-in method torch._C._is_torch_function_enabled}
80 0.001 0.000 0.001 0.000 {method 'contiguous' of 'torch._C._TensorBase' objects}
80 0.000 0.000 0.000 0.000 {method 'dim' of 'torch._C._TensorBase' objects}
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
80 0.004 0.000 0.004 0.000 {method 'matmul' of 'torch._C._TensorBase' objects}
160 0.000 0.000 0.000 0.000 {method 'size' of 'torch._C._TensorBase' objects}
20 0.005 0.000 0.005 0.000 {method 'softmax' of 'torch._C._TensorBase' objects}
20 0.002 0.000 0.002 0.000 {method 'sum' of 'torch._C._TensorBase' objects}
80 0.001 0.000 0.001 0.000 {method 't' of 'torch._C._TensorBase' objects}
100 0.001 0.000 0.001 0.000 {method 'transpose' of 'torch._C._TensorBase' objects}
81 0.000 0.000 0.000 0.000 {method 'values' of 'collections.OrderedDict' objects}
100 0.001 0.000 0.001 0.000 {method 'view' of 'torch._C._TensorBase' objects}
2187 function calls in 0.119 seconds
Ordered by: standard name
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.000 0.000 0.119 0.119 <string>:1(<module>)
20 0.000 0.000 0.000 0.000 _VF.py:25(__getattr__)
140 0.000 0.000 0.000 0.000 _jit_internal.py:713(is_scripting)
20 0.000 0.000 0.117 0.006 activation.py:923(forward)
1 0.000 0.000 0.000 0.000 container.py:184(__iter__)
20 0.000 0.000 0.001 0.000 functional.py:1473(softmax)
80 0.003 0.000 0.008 0.000 functional.py:1663(linear)
80 0.000 0.000 0.000 0.000 functional.py:1678(<listcomp>)
20 0.005 0.000 0.116 0.006 functional.py:3860(multi_head_attention_forward)
20 0.000 0.000 0.000 0.000 functional.py:3944(<listcomp>)
20 0.000 0.000 0.000 0.000 functional.py:950(dropout)
20 0.000 0.000 0.117 0.006 module.py:715(_call_impl)
120 0.000 0.000 0.000 0.000 module.py:765(__getattr__)
40 0.000 0.000 0.000 0.000 overrides.py:1058(has_torch_function)
260 0.000 0.000 0.000 0.000 overrides.py:1071(<genexpr>)
1 0.001 0.001 0.118 0.118 test_equal.py:24(run_test)
40 0.002 0.000 0.002 0.000 {built-in method bmm}
140 0.000 0.000 0.000 0.000 {built-in method builtins.any}
1 0.000 0.000 0.119 0.119 {built-in method builtins.exec}
160 0.000 0.000 0.000 0.000 {built-in method builtins.getattr}
1 0.000 0.000 0.000 0.000 {built-in method builtins.iter}
40 0.000 0.000 0.000 0.000 {built-in method builtins.len}
20 0.000 0.000 0.000 0.000 {built-in method dropout}
60 0.097 0.002 0.097 0.002 {built-in method equal}
20 0.000 0.000 0.000 0.000 {built-in method torch._C._get_tracing_state}
40 0.000 0.000 0.000 0.000 {built-in method torch._C._is_torch_function_enabled}
80 0.001 0.000 0.001 0.000 {method 'contiguous' of 'torch._C._TensorBase' objects}
80 0.000 0.000 0.000 0.000 {method 'dim' of 'torch._C._TensorBase' objects}
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
80 0.004 0.000 0.004 0.000 {method 'matmul' of 'torch._C._TensorBase' objects}
160 0.000 0.000 0.000 0.000 {method 'size' of 'torch._C._TensorBase' objects}
20 0.001 0.000 0.001 0.000 {method 'softmax' of 'torch._C._TensorBase' objects}
20 0.001 0.000 0.001 0.000 {method 'sum' of 'torch._C._TensorBase' objects}
80 0.001 0.000 0.001 0.000 {method 't' of 'torch._C._TensorBase' objects}
100 0.001 0.000 0.001 0.000 {method 'transpose' of 'torch._C._TensorBase' objects}
81 0.000 0.000 0.000 0.000 {method 'values' of 'collections.OrderedDict' objects}
100 0.001 0.000 0.001 0.000 {method 'view' of 'torch._C._TensorBase' objects}
2187 function calls in 0.122 seconds
Ordered by: standard name
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.000 0.000 0.122 0.122 <string>:1(<module>)
20 0.000 0.000 0.000 0.000 _VF.py:25(__getattr__)
140 0.000 0.000 0.000 0.000 _jit_internal.py:713(is_scripting)
20 0.000 0.000 0.120 0.006 activation.py:923(forward)
1 0.000 0.000 0.000 0.000 container.py:184(__iter__)
20 0.000 0.000 0.001 0.000 functional.py:1473(softmax)
80 0.003 0.000 0.008 0.000 functional.py:1663(linear)
80 0.000 0.000 0.000 0.000 functional.py:1678(<listcomp>)
20 0.005 0.000 0.119 0.006 functional.py:3860(multi_head_attention_forward)
20 0.000 0.000 0.000 0.000 functional.py:3944(<listcomp>)
20 0.000 0.000 0.000 0.000 functional.py:950(dropout)
20 0.000 0.000 0.120 0.006 module.py:715(_call_impl)
120 0.000 0.000 0.000 0.000 module.py:765(__getattr__)
40 0.000 0.000 0.000 0.000 overrides.py:1058(has_torch_function)
260 0.000 0.000 0.000 0.000 overrides.py:1071(<genexpr>)
1 0.001 0.001 0.121 0.121 test_equal.py:24(run_test)
40 0.001 0.000 0.001 0.000 {built-in method bmm}
140 0.000 0.000 0.000 0.000 {built-in method builtins.any}
1 0.000 0.000 0.122 0.122 {built-in method builtins.exec}
160 0.000 0.000 0.000 0.000 {built-in method builtins.getattr}
1 0.000 0.000 0.000 0.000 {built-in method builtins.iter}
40 0.000 0.000 0.000 0.000 {built-in method builtins.len}
20 0.000 0.000 0.000 0.000 {built-in method dropout}
60 0.100 0.002 0.100 0.002 {built-in method equal}
20 0.000 0.000 0.000 0.000 {built-in method torch._C._get_tracing_state}
40 0.000 0.000 0.000 0.000 {built-in method torch._C._is_torch_function_enabled}
80 0.001 0.000 0.001 0.000 {method 'contiguous' of 'torch._C._TensorBase' objects}
80 0.000 0.000 0.000 0.000 {method 'dim' of 'torch._C._TensorBase' objects}
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
80 0.004 0.000 0.004 0.000 {method 'matmul' of 'torch._C._TensorBase' objects}
160 0.000 0.000 0.000 0.000 {method 'size' of 'torch._C._TensorBase' objects}
20 0.001 0.000 0.001 0.000 {method 'softmax' of 'torch._C._TensorBase' objects}
20 0.001 0.000 0.001 0.000 {method 'sum' of 'torch._C._TensorBase' objects}
80 0.001 0.000 0.001 0.000 {method 't' of 'torch._C._TensorBase' objects}
100 0.001 0.000 0.001 0.000 {method 'transpose' of 'torch._C._TensorBase' objects}
81 0.000 0.000 0.000 0.000 {method 'values' of 'collections.OrderedDict' objects}
100 0.001 0.000 0.001 0.000 {method 'view' of 'torch._C._TensorBase' objects}
First time (after self-attention): 2.8897 ms
Second time (after self-attention): 0.1253 ms
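To make the numbers above easier to compare, the following sketch works out, for each of the three profiled runs, the share of total time attributed to equal and the average time per equal call. All figures are copied from the reports above; only the helper name summarize is made up for illustration.

```python
# (tottime of {built-in method equal}, total profiled time), in seconds,
# copied from the three cProfile reports above.
PROFILES = [(0.087, 0.121), (0.097, 0.119), (0.100, 0.122)]
EQUAL_CALLS_PER_RUN = 60   # ncalls column for equal: three calls per layer, 20 layers
STANDALONE_MS = 0.1260     # "Second time (before self-attention)"


def summarize(equal_s, total_s):
    """Return (share of total, ms per equal call, ratio to the standalone call)."""
    share = equal_s / total_s
    per_call_ms = equal_s / EQUAL_CALLS_PER_RUN * 1e3
    slowdown = per_call_ms / STANDALONE_MS
    return share, per_call_ms, slowdown


summaries = [summarize(equal_s, total_s) for equal_s, total_s in PROFILES]
for run, (share, per_call_ms, slowdown) in enumerate(summaries, start=1):
    print(f"run {run}: equal = {share:.1%} of total, "
          f"{per_call_ms:.2f} ms per call, "
          f"{slowdown:.1f}x the standalone call")
```

On these figures, equal accounts for roughly 72% to 82% of each run, and each call inside the attention layers averages about 1.5 ms, more than ten times the 0.126 ms measured for a standalone call.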
Interpretation
The first two times torch.equal is run, it is quick and all is fine. However, when running the self-attention layers, the torch.equal operation at some point starts taking more time (0.087 s in the first profiled run, slightly less than the 0.097 s and 0.100 s of the following two runs). When running torch.equal right after self-attention, it is indeed clear that it has slowed down significantly: the first call takes 2.8897 ms, compared to 0.1915 ms before self-attention.
I suspect this has to do with some memory movement, where the self-attention layers gradually take up more memory, which in turn slows down torch.equal. It surprises me, however, that only torch.equal is affected by this. Also, when I remove the torch.equal operations in nn.MultiheadAttention by hardcoding the booleans for my use case, none of the other operations start taking more time. It hence seems that these other operations, unlike torch.equal, do not need this slow memory movement.
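One way to probe this hypothesis is to separate the cost of torch.equal itself from time spent waiting on work that was queued earlier. CUDA kernels are launched asynchronously, and torch.equal has to wait for the device because it returns a Python bool, so host-side timers such as cProfile and timeit may charge earlier kernels to it. The sketch below is not part of the original measurements. timed_ms and compare_equal_timings are hypothetical helper names, and torch.cuda.synchronize is used to drain pending work before the clock starts.

```python
import time


def timed_ms(fn, sync=None):
    """Wall-clock time of fn() in ms; sync (if given) runs before and after fn."""
    if sync is not None:
        sync()  # drain work queued earlier so it is not charged to fn
    start = time.perf_counter()
    fn()
    if sync is not None:
        sync()  # make sure fn's own device work has finished
    return (time.perf_counter() - start) * 1e3


def compare_equal_timings():
    """Time torch.equal right after an attention layer, with and without a sync."""
    import torch  # imported here so timed_ms stays usable without PyTorch
    from torch import nn

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    sync = torch.cuda.synchronize if device == 'cuda' else None
    layer = nn.MultiheadAttention(256, 8).to(device)
    a = torch.randn(1024, 2, 256, device=device)
    c = torch.randn(1024, 2, 256, device=device)
    d = torch.randn(1024, 2, 256, device=device)

    layer(a, a, a)  # queue attention kernels, then time equal with no sync
    unsynced = timed_ms(lambda: torch.equal(c, d))
    layer(a, a, a)  # queue them again, but drain the queue before timing
    synced = timed_ms(lambda: torch.equal(c, d), sync=sync)
    print(f"equal without prior sync: {unsynced:.4f} ms")
    print(f"equal with prior sync:    {synced:.4f} ms")


# compare_equal_timings()  # uncomment on a machine with PyTorch installed
```

If the synced timing stays close to the standalone timing while the unsynced one is much larger, that would suggest the extra time is spent waiting for earlier kernels rather than on memory movement inside torch.equal. I have not verified this on the setups listed here.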
Environment
PyTorch version: 1.7.0.dev20200903
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A
OS: Fedora 32 (Thirty Two) (x86_64)
GCC version: (GCC) 10.2.1 20200723 (Red Hat 10.2.1-1)
Clang version: 10.0.0 (Fedora 10.0.0-2.fc32)
CMake version: version 3.17.4
Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.2.89
GPU models and configuration: GPU 0: GeForce GTX 1050 Ti
Nvidia driver version: 450.66
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.19.1
[pip3] torch==1.7.0.dev20200903
[pip3] torchvision==0.8.0.dev20200903
[conda] blas 1.0 mkl
[conda] cudatoolkit 10.2.89 hfd86e86_1
[conda] mkl 2020.2 256
[conda] mkl-service 2.3.0 py38he904b0f_0
[conda] mkl_fft 1.1.0 py38h23d657b_0
[conda] mkl_random 1.1.1 py38h0573a6f_0
[conda] numpy 1.19.1 py38hbc911f0_0
[conda] numpy-base 1.19.1 py38hfa32c7d_0
[conda] pytorch 1.7.0.dev20200903 py3.8_cuda10.2.89_cudnn7.6.5_0 pytorch-nightly
[conda] torchvision 0.8.0.dev20200903 py38_cu102 pytorch-nightly
Additional information
I’ve also tested on a GTX 1660 and a Titan V. In both cases, torch.equal was again the most time-consuming part, taking between 60% and 80% of the total time. Note also that the equal implementation was recently moved from THC to ATen in PyTorch 1.6 (see https://github.com/pytorch/pytorch/pull/36483). However, when running with PyTorch 1.5, I got similar timings.