Slow torch.equal on GPU (bottleneck of nn.MultiheadAttention)

Summary

When running nn.MultiheadAttention on the GPU, most time is spent on performing torch.equal (in https://github.com/pytorch/pytorch/blob/7036e91abd6ba5e9d20afa957913a0cd7c08be81/torch/nn/functional.py#L4100).

To reproduce

import copy
import cProfile
import timeit
import torch
from torch import nn

a = torch.randn(1024, 2, 256, device='cuda')
b = torch.randn(1024, 2, 256, device='cuda')
_ = torch.matmul(a.transpose(1, 2), b)

self_attention = nn.MultiheadAttention(256, 8).to('cuda')
layers = nn.ModuleList([copy.deepcopy(self_attention) for _ in range(20)])

c = torch.randn(1024, 2, 256, device='cuda')
d = torch.randn(1024, 2, 256, device='cuda')

t1 = timeit.timeit('torch.equal(c, d)', number=1, globals=globals())*1e3
t2 = timeit.timeit('torch.equal(c, d)', number=1, globals=globals())*1e3

print(f"First time (before self-attention): {t1: .4f} ms")
print(f"Second time (before self-attention): {t2: .4f} ms\n")


def run_test(a, b, layers):
    for layer in layers:
        a = layer(a+b, a+b, a)[0]


cProfile.runctx('run_test(a, b, layers)', globals=globals(), locals={})
cProfile.runctx('run_test(a, b, layers)', globals=globals(), locals={})
cProfile.runctx('run_test(a, b, layers)', globals=globals(), locals={})

t1 = timeit.timeit('torch.equal(c, d)', number=1, globals=globals())*1e3
t2 = timeit.timeit('torch.equal(c, d)', number=1, globals=globals())*1e3

print(f"First time (after self-attention): {t1: .4f} ms")
print(f"Second time (after self-attention): {t2: .4f} ms\n")

Results

First time (before self-attention):  0.1915 ms
Second time (before self-attention):  0.1260 ms

         2187 function calls in 0.121 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.121    0.121 <string>:1(<module>)
       20    0.000    0.000    0.000    0.000 _VF.py:25(__getattr__)
      140    0.000    0.000    0.000    0.000 _jit_internal.py:713(is_scripting)
       20    0.000    0.000    0.119    0.006 activation.py:923(forward)
        1    0.000    0.000    0.000    0.000 container.py:184(__iter__)
       20    0.000    0.000    0.005    0.000 functional.py:1473(softmax)
       80    0.003    0.000    0.008    0.000 functional.py:1663(linear)
       80    0.000    0.000    0.000    0.000 functional.py:1678(<listcomp>)
       20    0.006    0.000    0.118    0.006 functional.py:3860(multi_head_attention_forward)
       20    0.000    0.000    0.000    0.000 functional.py:3944(<listcomp>)
       20    0.000    0.000    0.000    0.000 functional.py:950(dropout)
       20    0.000    0.000    0.119    0.006 module.py:715(_call_impl)
      120    0.000    0.000    0.000    0.000 module.py:765(__getattr__)
       40    0.000    0.000    0.000    0.000 overrides.py:1058(has_torch_function)
      260    0.000    0.000    0.000    0.000 overrides.py:1071(<genexpr>)
        1    0.001    0.001    0.120    0.120 test_equal.py:24(run_test)
       40    0.006    0.000    0.006    0.000 {built-in method bmm}
      140    0.000    0.000    0.000    0.000 {built-in method builtins.any}
        1    0.000    0.000    0.121    0.121 {built-in method builtins.exec}
      160    0.000    0.000    0.000    0.000 {built-in method builtins.getattr}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.iter}
       40    0.000    0.000    0.000    0.000 {built-in method builtins.len}
       20    0.000    0.000    0.000    0.000 {built-in method dropout}
       60    0.087    0.001    0.087    0.001 {built-in method equal}
       20    0.000    0.000    0.000    0.000 {built-in method torch._C._get_tracing_state}
       40    0.000    0.000    0.000    0.000 {built-in method torch._C._is_torch_function_enabled}
       80    0.001    0.000    0.001    0.000 {method 'contiguous' of 'torch._C._TensorBase' objects}
       80    0.000    0.000    0.000    0.000 {method 'dim' of 'torch._C._TensorBase' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
       80    0.004    0.000    0.004    0.000 {method 'matmul' of 'torch._C._TensorBase' objects}
      160    0.000    0.000    0.000    0.000 {method 'size' of 'torch._C._TensorBase' objects}
       20    0.005    0.000    0.005    0.000 {method 'softmax' of 'torch._C._TensorBase' objects}
       20    0.002    0.000    0.002    0.000 {method 'sum' of 'torch._C._TensorBase' objects}
       80    0.001    0.000    0.001    0.000 {method 't' of 'torch._C._TensorBase' objects}
      100    0.001    0.000    0.001    0.000 {method 'transpose' of 'torch._C._TensorBase' objects}
       81    0.000    0.000    0.000    0.000 {method 'values' of 'collections.OrderedDict' objects}
      100    0.001    0.000    0.001    0.000 {method 'view' of 'torch._C._TensorBase' objects}


         2187 function calls in 0.119 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.119    0.119 <string>:1(<module>)
       20    0.000    0.000    0.000    0.000 _VF.py:25(__getattr__)
      140    0.000    0.000    0.000    0.000 _jit_internal.py:713(is_scripting)
       20    0.000    0.000    0.117    0.006 activation.py:923(forward)
        1    0.000    0.000    0.000    0.000 container.py:184(__iter__)
       20    0.000    0.000    0.001    0.000 functional.py:1473(softmax)
       80    0.003    0.000    0.008    0.000 functional.py:1663(linear)
       80    0.000    0.000    0.000    0.000 functional.py:1678(<listcomp>)
       20    0.005    0.000    0.116    0.006 functional.py:3860(multi_head_attention_forward)
       20    0.000    0.000    0.000    0.000 functional.py:3944(<listcomp>)
       20    0.000    0.000    0.000    0.000 functional.py:950(dropout)
       20    0.000    0.000    0.117    0.006 module.py:715(_call_impl)
      120    0.000    0.000    0.000    0.000 module.py:765(__getattr__)
       40    0.000    0.000    0.000    0.000 overrides.py:1058(has_torch_function)
      260    0.000    0.000    0.000    0.000 overrides.py:1071(<genexpr>)
        1    0.001    0.001    0.118    0.118 test_equal.py:24(run_test)
       40    0.002    0.000    0.002    0.000 {built-in method bmm}
      140    0.000    0.000    0.000    0.000 {built-in method builtins.any}
        1    0.000    0.000    0.119    0.119 {built-in method builtins.exec}
      160    0.000    0.000    0.000    0.000 {built-in method builtins.getattr}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.iter}
       40    0.000    0.000    0.000    0.000 {built-in method builtins.len}
       20    0.000    0.000    0.000    0.000 {built-in method dropout}
       60    0.097    0.002    0.097    0.002 {built-in method equal}
       20    0.000    0.000    0.000    0.000 {built-in method torch._C._get_tracing_state}
       40    0.000    0.000    0.000    0.000 {built-in method torch._C._is_torch_function_enabled}
       80    0.001    0.000    0.001    0.000 {method 'contiguous' of 'torch._C._TensorBase' objects}
       80    0.000    0.000    0.000    0.000 {method 'dim' of 'torch._C._TensorBase' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
       80    0.004    0.000    0.004    0.000 {method 'matmul' of 'torch._C._TensorBase' objects}
      160    0.000    0.000    0.000    0.000 {method 'size' of 'torch._C._TensorBase' objects}
       20    0.001    0.000    0.001    0.000 {method 'softmax' of 'torch._C._TensorBase' objects}
       20    0.001    0.000    0.001    0.000 {method 'sum' of 'torch._C._TensorBase' objects}
       80    0.001    0.000    0.001    0.000 {method 't' of 'torch._C._TensorBase' objects}
      100    0.001    0.000    0.001    0.000 {method 'transpose' of 'torch._C._TensorBase' objects}
       81    0.000    0.000    0.000    0.000 {method 'values' of 'collections.OrderedDict' objects}
      100    0.001    0.000    0.001    0.000 {method 'view' of 'torch._C._TensorBase' objects}


         2187 function calls in 0.122 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.122    0.122 <string>:1(<module>)
       20    0.000    0.000    0.000    0.000 _VF.py:25(__getattr__)
      140    0.000    0.000    0.000    0.000 _jit_internal.py:713(is_scripting)
       20    0.000    0.000    0.120    0.006 activation.py:923(forward)
        1    0.000    0.000    0.000    0.000 container.py:184(__iter__)
       20    0.000    0.000    0.001    0.000 functional.py:1473(softmax)
       80    0.003    0.000    0.008    0.000 functional.py:1663(linear)
       80    0.000    0.000    0.000    0.000 functional.py:1678(<listcomp>)
       20    0.005    0.000    0.119    0.006 functional.py:3860(multi_head_attention_forward)
       20    0.000    0.000    0.000    0.000 functional.py:3944(<listcomp>)
       20    0.000    0.000    0.000    0.000 functional.py:950(dropout)
       20    0.000    0.000    0.120    0.006 module.py:715(_call_impl)
      120    0.000    0.000    0.000    0.000 module.py:765(__getattr__)
       40    0.000    0.000    0.000    0.000 overrides.py:1058(has_torch_function)
      260    0.000    0.000    0.000    0.000 overrides.py:1071(<genexpr>)
        1    0.001    0.001    0.121    0.121 test_equal.py:24(run_test)
       40    0.001    0.000    0.001    0.000 {built-in method bmm}
      140    0.000    0.000    0.000    0.000 {built-in method builtins.any}
        1    0.000    0.000    0.122    0.122 {built-in method builtins.exec}
      160    0.000    0.000    0.000    0.000 {built-in method builtins.getattr}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.iter}
       40    0.000    0.000    0.000    0.000 {built-in method builtins.len}
       20    0.000    0.000    0.000    0.000 {built-in method dropout}
       60    0.100    0.002    0.100    0.002 {built-in method equal}
       20    0.000    0.000    0.000    0.000 {built-in method torch._C._get_tracing_state}
       40    0.000    0.000    0.000    0.000 {built-in method torch._C._is_torch_function_enabled}
       80    0.001    0.000    0.001    0.000 {method 'contiguous' of 'torch._C._TensorBase' objects}
       80    0.000    0.000    0.000    0.000 {method 'dim' of 'torch._C._TensorBase' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
       80    0.004    0.000    0.004    0.000 {method 'matmul' of 'torch._C._TensorBase' objects}
      160    0.000    0.000    0.000    0.000 {method 'size' of 'torch._C._TensorBase' objects}
       20    0.001    0.000    0.001    0.000 {method 'softmax' of 'torch._C._TensorBase' objects}
       20    0.001    0.000    0.001    0.000 {method 'sum' of 'torch._C._TensorBase' objects}
       80    0.001    0.000    0.001    0.000 {method 't' of 'torch._C._TensorBase' objects}
      100    0.001    0.000    0.001    0.000 {method 'transpose' of 'torch._C._TensorBase' objects}
       81    0.000    0.000    0.000    0.000 {method 'values' of 'collections.OrderedDict' objects}
      100    0.001    0.000    0.001    0.000 {method 'view' of 'torch._C._TensorBase' objects}


First time (after self-attention):  2.8897 ms
Second time (after self-attention):  0.1253 ms

Interpretation

The first two times torch.equal is run, it is quick and all is fine. However, when running the self-attention layers, at some point the torch.equal operation starts taking more time (1st run is slighly faster than following two runs). When running the torch.equal operation right after self-attention, it is indeed clear that torch.equal has significantly slowed down.

I suspect this has to do with some memory movement, where the self-attention layers slowly take more memory resulting in a slow time for torch.equal. It surprises me however that only torch.equal is affected by this. Also, when removing the torch.equal operations in nn.MultiheadAttention by hardcoding the booleans for my use-case, none of the other operations start taking more time. It hence seems that these other operations do not need this slow memory movement as opposed to torch.equal.

Environment

PyTorch version: 1.7.0.dev20200903
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Fedora 32 (Thirty Two) (x86_64)
GCC version: (GCC) 10.2.1 20200723 (Red Hat 10.2.1-1)
Clang version: 10.0.0 (Fedora 10.0.0-2.fc32)
CMake version: version 3.17.4

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.2.89
GPU models and configuration: GPU 0: GeForce GTX 1050 Ti
Nvidia driver version: 450.66
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.1
[pip3] torch==1.7.0.dev20200903
[pip3] torchvision==0.8.0.dev20200903
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               10.2.89              hfd86e86_1  
[conda] mkl                       2020.2                      256  
[conda] mkl-service               2.3.0            py38he904b0f_0  
[conda] mkl_fft                   1.1.0            py38h23d657b_0  
[conda] mkl_random                1.1.1            py38h0573a6f_0  
[conda] numpy                     1.19.1           py38hbc911f0_0  
[conda] numpy-base                1.19.1           py38hfa32c7d_0  
[conda] pytorch                   1.7.0.dev20200903 py3.8_cuda10.2.89_cudnn7.6.5_0    pytorch-nightly
[conda] torchvision               0.8.0.dev20200903      py38_cu102    pytorch-nightly

Additional information

I’ve also tested on a GTX 1660 and a Titan V. Also in both cases, torch.equal was the most time consuming part (taking between 60% and 80%) of the total time. Note also that the equal implementation was recently moved from THC to ATen in PyTorch 1.6 (see https://github.com/pytorch/pytorch/pull/36483). However, when running with PyTorch 1.5, I got similar timings.

Perhaps object identity check should be used there instead?

PS github may be a better place for this

Thanks for your response @googlebot . Identity check like c == d (or equivalently torch.eq(c, d)) is not the same as torch.equal(c, d), as torch.eq(c, d) is element-wise. However torch.eq(c, d).all() is equivalent to torch.equal(c, d). Below I test the torch.eq(c, d).all() alternative and it clearly shows that it does not suffer from these slowdowns!

New test

import copy
import timeit
import torch
from torch import nn

# Before self-attention (some computation)
c = torch.randn(1024, 2, 256, device='cuda')
d = torch.randn(1024, 2, 256, device='cuda')

t1 = timeit.timeit('torch.eq(c, d).all()', number=1, globals=globals())*1e3
t2 = timeit.timeit('torch.eq(c, d).all()', number=1, globals=globals())*1e3
t3 = timeit.timeit('torch.eq(c, d).all()', number=1, globals=globals())*1e3

print(f"First time with eq/all (before self-attention): {t1: .4f} ms")
print(f"Second time with eq/all (before self-attention): {t2: .4f} ms")
print(f"Third time with eq/all (before self-attention): {t3: .4f} ms\n")

t1 = timeit.timeit('torch.equal(c, d)', number=1, globals=globals())*1e3
t2 = timeit.timeit('torch.equal(c, d)', number=1, globals=globals())*1e3
t3 = timeit.timeit('torch.equal(c, d)', number=1, globals=globals())*1e3

print(f"First time with equal (before self-attention): {t1: .4f} ms")
print(f"Second time with equal (before self-attention): {t2: .4f} ms")
print(f"Third time with equal (before self-attention): {t3: .4f} ms\n")

# Perform self-attention (some computation)
a = torch.randn(1024, 2, 256, device='cuda')
b = torch.randn(1024, 2, 256, device='cuda')
_ = torch.matmul(a.transpose(1, 2), b)

self_attention = nn.MultiheadAttention(256, 8).to('cuda')
layers = nn.ModuleList([copy.deepcopy(self_attention) for _ in range(20)])

for layer in layers:
    a = layer(a+b, a+b, a)[0]

# After self-attention (some computation)
t1 = timeit.timeit('torch.eq(c, d).all()', number=1, globals=globals())*1e3
t2 = timeit.timeit('torch.eq(c, d).all()', number=1, globals=globals())*1e3
t3 = timeit.timeit('torch.eq(c, d).all()', number=1, globals=globals())*1e3

print(f"First time with eq/all (after self-attention): {t1: .4f} ms")
print(f"Second time with eq/all (after self-attention): {t2: .4f} ms")
print(f"Third time with eq/all (after self-attention): {t3: .4f} ms\n")

t1 = timeit.timeit('torch.equal(c, d)', number=1, globals=globals())*1e3
t2 = timeit.timeit('torch.equal(c, d)', number=1, globals=globals())*1e3
t3 = timeit.timeit('torch.equal(c, d)', number=1, globals=globals())*1e3

print(f"First time with equal (after self-attention): {t1: .4f} ms")
print(f"Second time with equal (after self-attention): {t2: .4f} ms")
print(f"Third time with equal (after self-attention): {t3: .4f} ms\n")

New results

First time with eq/all (before self-attention):  0.4065 ms
Second time with eq/all (before self-attention):  0.0777 ms
Third time with eq/all (before self-attention):  0.0763 ms

First time with equal (before self-attention):  0.1580 ms
Second time with equal (before self-attention):  0.1254 ms
Third time with equal (before self-attention):  0.1164 ms

First time with eq/all (after self-attention):  0.0825 ms
Second time with eq/all (after self-attention):  0.0775 ms
Third time with eq/all (after self-attention):  0.0776 ms

First time with equal (after self-attention):  4.1039 ms
Second time with equal (after self-attention):  0.1285 ms
Third time with equal (after self-attention):  0.1273 ms

Interpretation

It’s clear that the eq/all alternative does not suffer from this slowdown (first time after self-attention). I believe hence there must be a bug in torch.equal. I will create a corresponding issue on github.

For those interested, the link to the corresponding github issue: