Slow torch.equal on GPU (bottleneck of nn.MultiheadAttention)


When running nn.MultiheadAttention on the GPU, most of the time is spent in torch.equal (in

To reproduce

import copy
import cProfile
import timeit
import torch
from torch import nn

a = torch.randn(1024, 2, 256, device='cuda')
b = torch.randn(1024, 2, 256, device='cuda')
_ = torch.matmul(a.transpose(1, 2), b)

self_attention = nn.MultiheadAttention(256, 8).to('cuda')
layers = nn.ModuleList([copy.deepcopy(self_attention) for _ in range(20)])

c = torch.randn(1024, 2, 256, device='cuda')
d = torch.randn(1024, 2, 256, device='cuda')

t1 = timeit.timeit('torch.equal(c, d)', number=1, globals=globals())*1e3
t2 = timeit.timeit('torch.equal(c, d)', number=1, globals=globals())*1e3

print(f"First time (before self-attention): {t1: .4f} ms")
print(f"Second time (before self-attention): {t2: .4f} ms\n")

def run_test(a, b, layers):
    for layer in layers:
        a = layer(a+b, a+b, a)[0]

cProfile.runctx('run_test(a, b, layers)', globals=globals(), locals={})
cProfile.runctx('run_test(a, b, layers)', globals=globals(), locals={})
cProfile.runctx('run_test(a, b, layers)', globals=globals(), locals={})

t1 = timeit.timeit('torch.equal(c, d)', number=1, globals=globals())*1e3
t2 = timeit.timeit('torch.equal(c, d)', number=1, globals=globals())*1e3

print(f"First time (after self-attention): {t1: .4f} ms")
print(f"Second time (after self-attention): {t2: .4f} ms\n")


First time (before self-attention):  0.1915 ms
Second time (before self-attention):  0.1260 ms

         2187 function calls in 0.121 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.121    0.121 <string>:1(<module>)
       20    0.000    0.000    0.000    0.000
      140    0.000    0.000    0.000    0.000
       20    0.000    0.000    0.119    0.006
        1    0.000    0.000    0.000    0.000
       20    0.000    0.000    0.005    0.000
       80    0.003    0.000    0.008    0.000
       80    0.000    0.000    0.000    0.000<listcomp>)
       20    0.006    0.000    0.118    0.006
       20    0.000    0.000    0.000    0.000<listcomp>)
       20    0.000    0.000    0.000    0.000
       20    0.000    0.000    0.119    0.006
      120    0.000    0.000    0.000    0.000
       40    0.000    0.000    0.000    0.000
      260    0.000    0.000    0.000    0.000<genexpr>)
        1    0.001    0.001    0.120    0.120
       40    0.006    0.000    0.006    0.000 {built-in method bmm}
      140    0.000    0.000    0.000    0.000 {built-in method builtins.any}
        1    0.000    0.000    0.121    0.121 {built-in method builtins.exec}
      160    0.000    0.000    0.000    0.000 {built-in method builtins.getattr}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.iter}
       40    0.000    0.000    0.000    0.000 {built-in method builtins.len}
       20    0.000    0.000    0.000    0.000 {built-in method dropout}
       60    0.087    0.001    0.087    0.001 {built-in method equal}
       20    0.000    0.000    0.000    0.000 {built-in method torch._C._get_tracing_state}
       40    0.000    0.000    0.000    0.000 {built-in method torch._C._is_torch_function_enabled}
       80    0.001    0.000    0.001    0.000 {method 'contiguous' of 'torch._C._TensorBase' objects}
       80    0.000    0.000    0.000    0.000 {method 'dim' of 'torch._C._TensorBase' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
       80    0.004    0.000    0.004    0.000 {method 'matmul' of 'torch._C._TensorBase' objects}
      160    0.000    0.000    0.000    0.000 {method 'size' of 'torch._C._TensorBase' objects}
       20    0.005    0.000    0.005    0.000 {method 'softmax' of 'torch._C._TensorBase' objects}
       20    0.002    0.000    0.002    0.000 {method 'sum' of 'torch._C._TensorBase' objects}
       80    0.001    0.000    0.001    0.000 {method 't' of 'torch._C._TensorBase' objects}
      100    0.001    0.000    0.001    0.000 {method 'transpose' of 'torch._C._TensorBase' objects}
       81    0.000    0.000    0.000    0.000 {method 'values' of 'collections.OrderedDict' objects}
      100    0.001    0.000    0.001    0.000 {method 'view' of 'torch._C._TensorBase' objects}

         2187 function calls in 0.119 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.119    0.119 <string>:1(<module>)
       20    0.000    0.000    0.000    0.000
      140    0.000    0.000    0.000    0.000
       20    0.000    0.000    0.117    0.006
        1    0.000    0.000    0.000    0.000
       20    0.000    0.000    0.001    0.000
       80    0.003    0.000    0.008    0.000
       80    0.000    0.000    0.000    0.000<listcomp>)
       20    0.005    0.000    0.116    0.006
       20    0.000    0.000    0.000    0.000<listcomp>)
       20    0.000    0.000    0.000    0.000
       20    0.000    0.000    0.117    0.006
      120    0.000    0.000    0.000    0.000
       40    0.000    0.000    0.000    0.000
      260    0.000    0.000    0.000    0.000<genexpr>)
        1    0.001    0.001    0.118    0.118
       40    0.002    0.000    0.002    0.000 {built-in method bmm}
      140    0.000    0.000    0.000    0.000 {built-in method builtins.any}
        1    0.000    0.000    0.119    0.119 {built-in method builtins.exec}
      160    0.000    0.000    0.000    0.000 {built-in method builtins.getattr}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.iter}
       40    0.000    0.000    0.000    0.000 {built-in method builtins.len}
       20    0.000    0.000    0.000    0.000 {built-in method dropout}
       60    0.097    0.002    0.097    0.002 {built-in method equal}
       20    0.000    0.000    0.000    0.000 {built-in method torch._C._get_tracing_state}
       40    0.000    0.000    0.000    0.000 {built-in method torch._C._is_torch_function_enabled}
       80    0.001    0.000    0.001    0.000 {method 'contiguous' of 'torch._C._TensorBase' objects}
       80    0.000    0.000    0.000    0.000 {method 'dim' of 'torch._C._TensorBase' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
       80    0.004    0.000    0.004    0.000 {method 'matmul' of 'torch._C._TensorBase' objects}
      160    0.000    0.000    0.000    0.000 {method 'size' of 'torch._C._TensorBase' objects}
       20    0.001    0.000    0.001    0.000 {method 'softmax' of 'torch._C._TensorBase' objects}
       20    0.001    0.000    0.001    0.000 {method 'sum' of 'torch._C._TensorBase' objects}
       80    0.001    0.000    0.001    0.000 {method 't' of 'torch._C._TensorBase' objects}
      100    0.001    0.000    0.001    0.000 {method 'transpose' of 'torch._C._TensorBase' objects}
       81    0.000    0.000    0.000    0.000 {method 'values' of 'collections.OrderedDict' objects}
      100    0.001    0.000    0.001    0.000 {method 'view' of 'torch._C._TensorBase' objects}

         2187 function calls in 0.122 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.122    0.122 <string>:1(<module>)
       20    0.000    0.000    0.000    0.000
      140    0.000    0.000    0.000    0.000
       20    0.000    0.000    0.120    0.006
        1    0.000    0.000    0.000    0.000
       20    0.000    0.000    0.001    0.000
       80    0.003    0.000    0.008    0.000
       80    0.000    0.000    0.000    0.000<listcomp>)
       20    0.005    0.000    0.119    0.006
       20    0.000    0.000    0.000    0.000<listcomp>)
       20    0.000    0.000    0.000    0.000
       20    0.000    0.000    0.120    0.006
      120    0.000    0.000    0.000    0.000
       40    0.000    0.000    0.000    0.000
      260    0.000    0.000    0.000    0.000<genexpr>)
        1    0.001    0.001    0.121    0.121
       40    0.001    0.000    0.001    0.000 {built-in method bmm}
      140    0.000    0.000    0.000    0.000 {built-in method builtins.any}
        1    0.000    0.000    0.122    0.122 {built-in method builtins.exec}
      160    0.000    0.000    0.000    0.000 {built-in method builtins.getattr}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.iter}
       40    0.000    0.000    0.000    0.000 {built-in method builtins.len}
       20    0.000    0.000    0.000    0.000 {built-in method dropout}
       60    0.100    0.002    0.100    0.002 {built-in method equal}
       20    0.000    0.000    0.000    0.000 {built-in method torch._C._get_tracing_state}
       40    0.000    0.000    0.000    0.000 {built-in method torch._C._is_torch_function_enabled}
       80    0.001    0.000    0.001    0.000 {method 'contiguous' of 'torch._C._TensorBase' objects}
       80    0.000    0.000    0.000    0.000 {method 'dim' of 'torch._C._TensorBase' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
       80    0.004    0.000    0.004    0.000 {method 'matmul' of 'torch._C._TensorBase' objects}
      160    0.000    0.000    0.000    0.000 {method 'size' of 'torch._C._TensorBase' objects}
       20    0.001    0.000    0.001    0.000 {method 'softmax' of 'torch._C._TensorBase' objects}
       20    0.001    0.000    0.001    0.000 {method 'sum' of 'torch._C._TensorBase' objects}
       80    0.001    0.000    0.001    0.000 {method 't' of 'torch._C._TensorBase' objects}
      100    0.001    0.000    0.001    0.000 {method 'transpose' of 'torch._C._TensorBase' objects}
       81    0.000    0.000    0.000    0.000 {method 'values' of 'collections.OrderedDict' objects}
      100    0.001    0.000    0.001    0.000 {method 'view' of 'torch._C._TensorBase' objects}

First time (after self-attention):  2.8897 ms
Second time (after self-attention):  0.1253 ms


The first two calls to torch.equal are quick and all is fine. Once the self-attention layers run, however, torch.equal accounts for most of the time in every profile: 0.087 s of 0.121 s in the first run (slightly less than in the following two runs, at 0.097 s and 0.100 s). The torch.equal call issued right after self-attention confirms the slowdown: it takes 2.89 ms, against roughly 0.13 ms before self-attention and on the call that follows it.
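The profiles above are ordered by standard name, which makes the hot spot harder to see. As a minimal sketch using only the standard library, the same kind of profile can be sorted by total time so the most expensive call comes first. Here busy_work and profile_sorted are hypothetical names, and busy_work is only a stand-in for run_test(a, b, layers).

```python
import cProfile
import io
import pstats


def busy_work(n=200_000):
    # Hypothetical stand-in for run_test(a, b, layers).
    return sum(i * i for i in range(n))


def profile_sorted(func, limit=5):
    """Profile func() and return the report sorted by tottime (time spent
    in each function itself, excluding the functions it calls)."""
    profiler = cProfile.Profile()
    profiler.runcall(func)
    buffer = io.StringIO()
    stats = pstats.Stats(profiler, stream=buffer)
    stats.sort_stats("tottime").print_stats(limit)
    return buffer.getvalue()


report = profile_sorted(busy_work)
print(report)
```

With this ordering, a call such as {built-in method equal} would be listed at the top of the report instead of in the middle of it.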

I suspect this has to do with some memory movement: the self-attention layers gradually take more memory, which results in a slow torch.equal. It surprises me, however, that only torch.equal is affected. When I remove the torch.equal operations from nn.MultiheadAttention by hardcoding the booleans for my use case, none of the other operations start taking more time. It hence seems that these other operations do not need this slow memory movement, whereas torch.equal does.
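One thing worth ruling out before blaming memory movement: CUDA kernels are launched asynchronously, while torch.equal returns a Python bool and therefore has to wait for the GPU. A wall-clock timer around torch.equal may then also include work that was queued earlier. The sketch below is an assumption about a possible cause, not a confirmed diagnosis. It times torch.equal with and without calling torch.cuda.synchronize() first. The helpers sync, time_equal_ms and queue_work are hypothetical, and queue_work is only a stand-in for the attention layers.

```python
import time

import torch

# Fall back to the CPU so the sketch still runs without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def sync():
    """Block until all queued CUDA kernels have finished (no-op on CPU)."""
    if device.type == "cuda":
        torch.cuda.synchronize()


def time_equal_ms(x, y, drain_first):
    """Time one torch.equal call, optionally draining the queue beforehand."""
    if drain_first:
        sync()
    start = time.perf_counter()
    result = torch.equal(x, y)  # returns a Python bool, so it waits for the GPU
    return result, (time.perf_counter() - start) * 1e3


def queue_work(n=50):
    """Hypothetical stand-in for the attention layers: launch some kernels."""
    m = torch.randn(512, 512, device=device)
    for _ in range(n):
        _ = m @ m


c = torch.randn(1024, 2, 256, device=device)
d = torch.randn(1024, 2, 256, device=device)

queue_work()
same_a, undrained_ms = time_equal_ms(c, d, drain_first=False)
queue_work()
same_b, drained_ms = time_equal_ms(c, d, drain_first=True)

print(f"torch.equal without prior synchronize: {undrained_ms: .4f} ms")
print(f"torch.equal after synchronize:         {drained_ms: .4f} ms")
```

If the first figure is much larger than the second on a GPU, the extra time belongs to the work queued before torch.equal rather than to the comparison itself.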


PyTorch version: 1.7.0.dev20200903
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Fedora 32 (Thirty Two) (x86_64)
GCC version: (GCC) 10.2.1 20200723 (Red Hat 10.2.1-1)
Clang version: 10.0.0 (Fedora 10.0.0-2.fc32)
CMake version: version 3.17.4

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.2.89
GPU models and configuration: GPU 0: GeForce GTX 1050 Ti
Nvidia driver version: 450.66
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.1
[pip3] torch==1.7.0.dev20200903
[pip3] torchvision==0.8.0.dev20200903
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               10.2.89              hfd86e86_1  
[conda] mkl                       2020.2                      256  
[conda] mkl-service               2.3.0            py38he904b0f_0  
[conda] mkl_fft                   1.1.0            py38h23d657b_0  
[conda] mkl_random                1.1.1            py38h0573a6f_0  
[conda] numpy                     1.19.1           py38hbc911f0_0  
[conda] numpy-base                1.19.1           py38hfa32c7d_0  
[conda] pytorch                   1.7.0.dev20200903 py3.8_cuda10.2.89_cudnn7.6.5_0    pytorch-nightly
[conda] torchvision               0.8.0.dev20200903      py38_cu102    pytorch-nightly

Additional information

I’ve also tested on a GTX 1660 and a Titan V. In both cases, torch.equal was again the most time-consuming part, taking between 60% and 80% of the total time. Note also that the equal implementation was recently moved from THC to ATen in PyTorch 1.6 (see ). However, when running with PyTorch 1.5, I got similar timings.

Perhaps object identity check should be used there instead?

PS: GitHub may be a better place for this.

Thanks for your response, @googlebot. A check like c == d (or equivalently torch.eq(c, d)) is not the same as torch.equal(c, d), since torch.eq(c, d) is element-wise. However, torch.eq(c, d).all() is equivalent to torch.equal(c, d) for tensors of the same shape. Below I test the torch.eq(c, d).all() alternative, and it clearly does not suffer from these slowdowns!
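To keep the different checks apart, here is a small sketch on CPU tensors comparing Python's identity test (is), the element-wise torch.eq followed by .all(), and torch.equal. The tensor names are made up for illustration. Two differences are worth noting: torch.eq(...).all() returns a 0-dim bool tensor rather than a Python bool, and torch.eq broadcasts, whereas torch.equal returns False when the sizes differ.

```python
import torch

a = torch.ones(2, 3)
b = torch.ones(2, 3)

# Identity: two separately computed results are different objects,
# even when they hold the same values.
q = a + b
k = a + b
same_object = q is k                   # False
same_values = torch.equal(q, k)        # True, Python bool

# Element-wise comparison reduced with .all() gives a 0-dim bool tensor.
eq_all = torch.eq(q, k).all()
eq_all_is_tensor = isinstance(eq_all, torch.Tensor)

# Shapes differ: torch.equal says False, torch.eq broadcasts first.
row = torch.tensor([1.0, 2.0])
two_rows = torch.tensor([[1.0, 2.0], [1.0, 2.0]])
shape_equal = torch.equal(row, two_rows)             # False
shape_eq_all = bool(torch.eq(row, two_rows).all())   # True

print(same_object, same_values, bool(eq_all), shape_equal, shape_eq_all)
```

Note that in the reproduction script the query and key are both written as a+b, which creates two distinct tensors, so an identity check would not treat them as the same input.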

New test

import copy
import timeit
import torch
from torch import nn

# Before self-attention (some computation)
c = torch.randn(1024, 2, 256, device='cuda')
d = torch.randn(1024, 2, 256, device='cuda')

t1 = timeit.timeit('torch.eq(c, d).all()', number=1, globals=globals())*1e3
t2 = timeit.timeit('torch.eq(c, d).all()', number=1, globals=globals())*1e3
t3 = timeit.timeit('torch.eq(c, d).all()', number=1, globals=globals())*1e3

print(f"First time with eq/all (before self-attention): {t1: .4f} ms")
print(f"Second time with eq/all (before self-attention): {t2: .4f} ms")
print(f"Third time with eq/all (before self-attention): {t3: .4f} ms\n")

t1 = timeit.timeit('torch.equal(c, d)', number=1, globals=globals())*1e3
t2 = timeit.timeit('torch.equal(c, d)', number=1, globals=globals())*1e3
t3 = timeit.timeit('torch.equal(c, d)', number=1, globals=globals())*1e3

print(f"First time with equal (before self-attention): {t1: .4f} ms")
print(f"Second time with equal (before self-attention): {t2: .4f} ms")
print(f"Third time with equal (before self-attention): {t3: .4f} ms\n")

# Perform self-attention (some computation)
a = torch.randn(1024, 2, 256, device='cuda')
b = torch.randn(1024, 2, 256, device='cuda')
_ = torch.matmul(a.transpose(1, 2), b)

self_attention = nn.MultiheadAttention(256, 8).to('cuda')
layers = nn.ModuleList([copy.deepcopy(self_attention) for _ in range(20)])

for layer in layers:
    a = layer(a+b, a+b, a)[0]

# After self-attention (some computation)
t1 = timeit.timeit('torch.eq(c, d).all()', number=1, globals=globals())*1e3
t2 = timeit.timeit('torch.eq(c, d).all()', number=1, globals=globals())*1e3
t3 = timeit.timeit('torch.eq(c, d).all()', number=1, globals=globals())*1e3

print(f"First time with eq/all (after self-attention): {t1: .4f} ms")
print(f"Second time with eq/all (after self-attention): {t2: .4f} ms")
print(f"Third time with eq/all (after self-attention): {t3: .4f} ms\n")

t1 = timeit.timeit('torch.equal(c, d)', number=1, globals=globals())*1e3
t2 = timeit.timeit('torch.equal(c, d)', number=1, globals=globals())*1e3
t3 = timeit.timeit('torch.equal(c, d)', number=1, globals=globals())*1e3

print(f"First time with equal (after self-attention): {t1: .4f} ms")
print(f"Second time with equal (after self-attention): {t2: .4f} ms")
print(f"Third time with equal (after self-attention): {t3: .4f} ms\n")

New results

First time with eq/all (before self-attention):  0.4065 ms
Second time with eq/all (before self-attention):  0.0777 ms
Third time with eq/all (before self-attention):  0.0763 ms

First time with equal (before self-attention):  0.1580 ms
Second time with equal (before self-attention):  0.1254 ms
Third time with equal (before self-attention):  0.1164 ms

First time with eq/all (after self-attention):  0.0825 ms
Second time with eq/all (after self-attention):  0.0775 ms
Third time with eq/all (after self-attention):  0.0776 ms

First time with equal (after self-attention):  4.1039 ms
Second time with equal (after self-attention):  0.1285 ms
Third time with equal (after self-attention):  0.1273 ms


It’s clear that the eq/all alternative does not suffer from this slowdown (compare the first call after self-attention). I therefore believe there must be a bug in torch.equal. I will create a corresponding issue on GitHub.

For those interested, the link to the corresponding github issue: