Quantized model from the dynamic quantization on BERT tutorial is slower than the original model

Hello, I am new to deep learning and PyTorch.

I'm interested in making deep-learning models fast,
so I tried running the dynamically quantized model from the BERT tutorial on pytorch.org.

I ran the program on an Intel Xeon E5-2620 v4 system
and confirmed that the quantized model is smaller than the original (438 MB → 181.5 MB),
but the total evaluation time of the quantized model is slightly longer than the original's (122.3 s → 123.2 s).
I ran the program on a second machine with the same spec, and the result was the same.

I also ran the program on AMD Ryzen 2600 and Ryzen 3950X systems.
In that case overall execution was slower than on the Intel system,
but the total evaluation time of the quantized model was shorter than the original's.
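
For reference, the core quantization step in the tutorial is roughly the following (a minimal sketch; the tutorial loads a locally fine-tuned MRPC checkpoint, replaced here by a public one):

import torch
from transformers import BertForSequenceClassification

# Load an FP32 BERT fine-tuned on MRPC (stand-in for the tutorial's local
# checkpoint), then replace every nn.Linear with a dynamically quantized
# int8 version; weights are stored as int8, activations are quantized on
# the fly at inference time.
model = BertForSequenceClassification.from_pretrained("bert-base-cased-finetuned-mrpc")
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)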


As given in the tutorial here: https://pytorch.org/tutorials/intermediate/dynamic_quantization_bert_tutorial.html, the size reduction from 438 MB -> 181.5 MB matches. For the time, we expect approximately a 2x performance improvement on a Xeon E5-2620 v4. I can't say what the issue is without access to your env setup.
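
To share the env setup, the simplest thing is to print PyTorch's parallelization and build info:

import torch

# Threading configuration (OpenMP/MKL thread counts), which strongly
# affects dynamic-quantization benchmarks.
print(torch.__config__.parallel_info())

# How this PyTorch build was compiled (MKL, MKL-DNN, CPU capability, ...).
print(torch.__config__.show())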

Hi,
I'm hitting the same problem: no speed-up.
I ran the BERT tutorial in a conda env on an iMac.

FP32: 437.98 MB, 106.6 sec, acc&f1 0.8811
QINT8: 181.44 MB, 108.2 sec, acc&f1 0.8799
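
The timings above are wall-clock over the tutorial's evaluation loop; a minimal stand-in for that measurement (hypothetical helper, assuming a model and a dict of input tensors) looks like:

import time
import torch

def time_model_evaluation(model, inputs, n_iter=10):
    # Average wall-clock seconds per forward pass, with autograd disabled
    # so the comparison isolates inference cost.
    model.eval()
    with torch.no_grad():
        start = time.time()
        for _ in range(n_iter):
            model(**inputs)
    return (time.time() - start) / n_iter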

python -V
3.8.5

torch.__version__ 
1.6.0

<conda info>

    active environment : quant_bert
    active env location : /Users/*******/anaconda/envs/quant_bert
            shell level : 1
       user config file : /Users/*******/.condarc
 populated config files :
          conda version : 4.5.11
    conda-build version : 2.0.2
         python version : 3.5.6.final.0
       base environment : /Users/*******/anaconda  (writable)
           channel URLs : https://repo.anaconda.com/pkgs/main/osx-64
                          https://repo.anaconda.com/pkgs/main/noarch
                          https://repo.anaconda.com/pkgs/free/osx-64
                          https://repo.anaconda.com/pkgs/free/noarch
                          https://repo.anaconda.com/pkgs/r/osx-64
                          https://repo.anaconda.com/pkgs/r/noarch
                          https://repo.anaconda.com/pkgs/pro/osx-64
                          https://repo.anaconda.com/pkgs/pro/noarch
          package cache : /Users/*******/anaconda/pkgs
                          /Users/*******/.conda/pkgs
       envs directories : /Users/*******/anaconda/envs
                          /Users/*******/.conda/envs
               platform : osx-64
             user-agent : conda/4.5.11 requests/2.14.2 CPython/3.5.6 Darwin/19.6.0 OSX/10.15.6
<H/W spec.>
macOS Catalina 10.15.6
iMac (Retina 5K, 27-inch, Late 2014) 
Quad-Core Intel Core i5, 3.5GHz
32GB 1600 MHz DDR3
AMD Radeon R9 M290X 2 GB
L2 cache: 256 KB / core
L3 cache: 6 MB

Thanks.

Thanks!

My env setup of the Intel system is as follows:

conda info

  active environment : danchu
    active env location : /home/danchu/anaconda3/envs/danchu
             shell level : 1
        user config file : /home/danchu/.condarc
  populated config files : /home/danchu/.condarc
           conda version : 4.8.3
     conda-build version : 3.18.11
          python version : 3.8.3.final.0
        virtual packages : __cuda=10.1
                           __glibc=2.17
        base environment : /home/danchu/anaconda3  (writable)
            channel URLs : https://repo.anaconda.com/pkgs/main/linux-64
                           https://repo.anaconda.com/pkgs/main/noarch
                           https://repo.anaconda.com/pkgs/r/linux-64
                           https://repo.anaconda.com/pkgs/r/noarch
           package cache : /home/danchu/anaconda3/pkgs
                           /home/danchu/.conda/pkgs
        envs directories : /home/danchu/anaconda3/envs
                           /home/danchu/.conda/envs
                platform : linux-64
              user-agent : conda/4.8.3 requests/2.24.0 CPython/3.8.3 Linux/3.10.0-957.el7.x86_64 centos/7.6.1810 glibc/2.17
                 UID:GID : 1994:1994
              netrc file : None            
            offline mode : False

package version

numpy              1.19.2
torch              1.8.0.dev20201014+cu101
transformers       3.3.1

torch.__config__.parallel_info()

1.8.0.dev20201014+cu101
ATen/Parallel:
        at::get_num_threads() : 1
        at::get_num_interop_threads() : 16
OpenMP 201511 (a.k.a. OpenMP 4.5)
        omp_get_max_threads() : 1
Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
        mkl_get_max_threads() : 1
Intel(R) MKL-DNN v1.6.0 (Git Hash 5ef631a030a6f73131c77892041042805a06064f)
std::thread::hardware_concurrency() : 32
Environment variables:
        OMP_NUM_THREADS : [not set]
        MKL_NUM_THREADS : [not set]
ATen parallel backend: OpenMP

torch.__config__.show()

PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v1.6.0 (Git Hash 5ef631a030a6f73131c77892041042805a06064f)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 10.1
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75
  - CuDNN 7.6.3
  - Magma 2.5.2
  - Build settings: BLAS=MKL, BUILD_TYPE=Release, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DUSE_VULKAN_WRAPPER -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, USE_CUDA=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON,

HW SPEC

     cpu : Intel(R) Xeon(R) CPU E5-2620 v4 @ 2.10GHz x 2
mem size : 65673788 kB
     gpu : Geforce GTX 1080ti x4

Can you two please try running with OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 ./binary ?
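
Equivalently, the intra-op thread count can be pinned from inside the script; a small sketch:

import torch

# Equivalent to OMP_NUM_THREADS / MKL_NUM_THREADS for intra-op parallelism;
# call this before any parallel work starts.
torch.set_num_threads(1)
print(torch.__config__.parallel_info())  # verify the setting took effect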

I've tried as follows (note that bash rejects spaces around the = in an export):

export MKL_NUM_THREADS=1
export OMP_NUM_THREADS=1
python Bert_quantize_tutorial.py

torch.__config__.parallel_info()

1.8.0.dev20201014+cu101
ATen/Parallel:
        at::get_num_threads() : 1
        at::get_num_interop_threads() : 16
OpenMP 201511 (a.k.a. OpenMP 4.5)
        omp_get_max_threads() : 1
Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
        mkl_get_max_threads() : 1
Intel(R) MKL-DNN v1.6.0 (Git Hash 5ef631a030a6f73131c77892041042805a06064f)
std::thread::hardware_concurrency() : 32
Environment variables:
        OMP_NUM_THREADS : 1
        MKL_NUM_THREADS : 1

But I can't get any speed-up with this setting.

Result

Size (MB): 438.021641
Size (MB): 181.502781
Evaluating: 100%|████████████████████████████████████████████████████████████████████████████████| 51/51 [02:00<00:00,  2.37s/it]
{'acc': 0.8602941176470589, 'f1': 0.9018932874354562, 'acc_and_f1': 0.8810937025412575}
Evaluate total time (seconds): 120.8
Evaluating: 100%|████████████████████████████████████████████████████████████████████████████████| 51/51 [02:01<00:00,  2.39s/it]
{'acc': 0.8578431372549019, 'f1': 0.8999999999999999, 'acc_and_f1': 0.878921568627451}
Evaluate total time (seconds): 122.0

Thanks.

1] What do you mean by “./binary”?

2] Before applying your suggestion:
torch.__config__.parallel_info()

ATen/Parallel:
    at::get_num_threads() : 1
    at::get_num_interop_threads() : 2
OpenMP not found
Intel(R) Math Kernel Library Version 2019.0.4 Product Build 20190411 for Intel(R) 64 architecture applications
    mkl_get_max_threads() : 4
Intel(R) MKL-DNN v1.5.0 (Git Hash e2ac1fac44c5078ca927cb9b90e1b3066a0b2ed0)
std::thread::hardware_concurrency() : 4
Environment variables:
    OMP_NUM_THREADS : [not set]
    MKL_NUM_THREADS : [not set]
ATen parallel backend: OpenMP

Size (MB): 438.017609
Size (MB): 181.499089
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 51/51 [01:52<00:00,  2.21s/it]
{'acc': 0.8602941176470589, 'f1': 0.9018932874354562, 'acc_and_f1': 0.8810937025412575}
Evaluate total time (seconds): 112.9
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 51/51 [01:51<00:00,  2.18s/it]
{'acc': 0.8578431372549019, 'f1': 0.8993055555555555, 'acc_and_f1': 0.8785743464052287}
Evaluate total time (seconds): 111.2

3] After,

ATen/Parallel:
    at::get_num_threads() : 1
    at::get_num_interop_threads() : 2
OpenMP not found
Intel(R) Math Kernel Library Version 2019.0.4 Product Build 20190411 for Intel(R) 64 architecture applications
    mkl_get_max_threads() : 1
Intel(R) MKL-DNN v1.5.0 (Git Hash e2ac1fac44c5078ca927cb9b90e1b3066a0b2ed0)
std::thread::hardware_concurrency() : 4
Environment variables:
    OMP_NUM_THREADS : 1
    MKL_NUM_THREADS : 1
ATen parallel backend: OpenMP

Size (MB): 438.017609
Size (MB): 181.499089
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 51/51 [01:49<00:00,  2.15s/it]
{'acc': 0.8602941176470589, 'f1': 0.9018932874354562, 'acc_and_f1': 0.8810937025412575}
Evaluate total time (seconds): 110.0
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 51/51 [01:47<00:00,  2.12s/it]
{'acc': 0.8578431372549019, 'f1': 0.8993055555555555, 'acc_and_f1': 0.8785743464052287}
Evaluate total time (seconds): 108.0

I couldn’t get a meaningful speed-up.

P.S. The interface of the convert_examples_to_features function from transformers has changed,
so I modified run.py accordingly.
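
A hypothetical sketch of that kind of adaptation, assuming the glue_convert_examples_to_features helper shipped with transformers 3.x (names and arguments are illustrative, not the exact run.py change):

from transformers import BertTokenizer, InputExample, glue_convert_examples_to_features

# Hypothetical example: the 3.x GLUE helper drops older keyword arguments
# (e.g. pad_on_left / pad_token) in favor of a simpler interface.
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
examples = [InputExample(guid="0", text_a="A sentence.", text_b="Another.", label="1")]
features = glue_convert_examples_to_features(
    examples, tokenizer, max_length=128,
    label_list=["0", "1"], output_mode="classification",
)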

I've tried an Intel machine with a different spec, as follows:

HW SPEC

     cpu : Intel(R) Xeon(R) Silver 4110 CPU @ 2.10GHz (Skylake)
mem size : 65673788 kB
     gpu : Geforce GTX 1080ti x4

torch.__config__.parallel_info()

1.8.0.dev20201014+cu101
ATen/Parallel:
        at::get_num_threads() : 1
        at::get_num_interop_threads() : 8
OpenMP 201511 (a.k.a. OpenMP 4.5)
        omp_get_max_threads() : 1
Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
        mkl_get_max_threads() : 1
Intel(R) MKL-DNN v1.6.0 (Git Hash 5ef631a030a6f73131c77892041042805a06064f)
std::thread::hardware_concurrency() : 16
Environment variables:
        OMP_NUM_THREADS : [not set]
        MKL_NUM_THREADS : [not set]
ATen parallel backend: OpenMP

torch.__config__.show()

PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v1.6.0 (Git Hash 5ef631a030a6f73131c77892041042805a06064f)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 10.1
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75
  - CuDNN 7.6.3
  - Magma 2.5.2
  - Build settings: BLAS=MKL, BUILD_TYPE=Release, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DUSE_VULKAN_WRAPPER -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, USE_CUDA=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON,

Result

Size (MB): 438.021641
Size (MB): 181.502781
Evaluating: 100%|████████████████████████████████████████████████████████████████████████████████| 51/51 [03:19<00:00,  3.92s/it]
{'acc': 0.8602941176470589, 'f1': 0.9018932874354562, 'acc_and_f1': 0.8810937025412575}
Evaluate total time (seconds): 200.0
Evaluating: 100%|████████████████████████████████████████████████████████████████████████████████| 51/51 [01:57<00:00,  2.31s/it]
{'acc': 0.8578431372549019, 'f1': 0.8999999999999999, 'acc_and_f1': 0.878921568627451}
Evaluate total time (seconds): 117.6

In this case, I get a considerable speed-up with the same env settings and code.

I think this issue is related to Intel Broadwell and Haswell processors (the E5-2620 v4 is Broadwell).

The CPU micro-architecture seems to matter.
My MacBook has a CPU based on Kaby Lake (the successor of Skylake)
and shows the expected speed-up with the quantized model.


One thing I forgot to mention:

Another tutorial, “Dynamic Quantization on an LSTM Word Language Model”, does show a speed-up on the Xeon E5-2620 v4 system.
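
For reference, that tutorial's core step quantizes both the LSTM and Linear modules; a self-contained sketch with a toy model standing in for the tutorial's language model:

import torch
import torch.nn as nn

class TinyLM(nn.Module):
    # Toy stand-in for the tutorial's LSTM word language model.
    def __init__(self, vocab=1000, hidden=128):
        super().__init__()
        self.emb = nn.Embedding(vocab, hidden)
        self.lstm = nn.LSTM(hidden, hidden)
        self.out = nn.Linear(hidden, vocab)

    def forward(self, x):
        h, _ = self.lstm(self.emb(x))
        return self.out(h)

model = TinyLM()

# As in the tutorial: dynamically quantize nn.LSTM and nn.Linear to int8.
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.LSTM, torch.nn.Linear}, dtype=torch.qint8
)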

Env setup and result as follows:
torch.__config__.parallel_info()

ATen/Parallel:
        at::get_num_threads() : 1
        at::get_num_interop_threads() : 16
OpenMP 201511 (a.k.a. OpenMP 4.5)
        omp_get_max_threads() : 1
Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
        mkl_get_max_threads() : 1
Intel(R) MKL-DNN v1.6.0 (Git Hash 5ef631a030a6f73131c77892041042805a06064f)
std::thread::hardware_concurrency() : 32
Environment variables:
        OMP_NUM_THREADS : [not set]
        MKL_NUM_THREADS : [not set]
ATen parallel backend: OpenMP

torch.__config__.show()

PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v1.6.0 (Git Hash 5ef631a030a6f73131c77892041042805a06064f)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 10.1
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75
  - CuDNN 7.6.3
  - Magma 2.5.2
  - Build settings: BLAS=MKL, BUILD_TYPE=Release, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DUSE_VULKAN_WRAPPER -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, USE_CUDA=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON,

Result

loss: 5.167
elapsed time (seconds): 225.1
loss: 5.168
elapsed time (seconds): 148.0


We haven't benchmarked Haswell, but as long as it has AVX2 we should see similar speedups.

I also tried dynamic quantization on the transformer word language model (it is in examples/word_language_model in the pytorch/examples repo on GitHub).
In this case too, it shows some speed-up on the Xeon E5-2620 v4 system.

Env setup and result as follows:
torch.__config__.parallel_info()

ATen/Parallel:
        at::get_num_threads() : 1
        at::get_num_interop_threads() : 16
OpenMP 201511 (a.k.a. OpenMP 4.5)
        omp_get_max_threads() : 1
Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
        mkl_get_max_threads() : 1
Intel(R) MKL-DNN v1.6.0 (Git Hash 5ef631a030a6f73131c77892041042805a06064f)
std::thread::hardware_concurrency() : 32
Environment variables:
        OMP_NUM_THREADS : [not set]
        MKL_NUM_THREADS : [not set]
ATen parallel backend: OpenMP

torch.__config__.show()

PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v1.6.0 (Git Hash 5ef631a030a6f73131c77892041042805a06064f)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 10.1
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75
  - CuDNN 7.6.3
  - Magma 2.5.2
  - Build settings: BLAS=MKL, BUILD_TYPE=Release, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DUSE_VULKAN_WRAPPER -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, USE_CUDA=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON,

Result


I'm hitting the same problem on a Xeon E5-2620 v4. Thanks for the last two attempts.

Is the problem resolved? cc @danchu @dskhudia

Hello,

I am also trying to reproduce the results of the dynamic quantization example provided in the (beta) Dynamic Quantization on BERT tutorial (https://pytorch.org/tutorials/intermediate/dynamic_quantization_bert_tutorial.html):

from transformers import BertForSequenceClassification, BertTokenizerFast
import torch
import os

tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased-finetuned-mrpc")
model = BertForSequenceClassification.from_pretrained("bert-base-cased-finetuned-mrpc")

batch_size = 32
sequence_length = 128
vocab_size = len(tokenizer.vocab)

token_ids = torch.randint(vocab_size, (batch_size, sequence_length))

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

I then compare the memory footprint and performance of the baseline model and its quantized counterpart:

print_size_of_model(model)
%timeit model(token_ids)

Size (MB): 433.3
2.16 s ± 30.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

print_size_of_model(quantized_model)
%timeit quantized_model(token_ids)

Size (MB): 176.8
2.08 s ± 16.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Quantization seems to be working fine based on the significantly reduced memory footprint, but I cannot reproduce the execution-speed benefits highlighted in the example (close to 50%).
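
A quick thing to check (sketch): which quantized engine PyTorch selected. FBGEMM is the expected x86 backend, and it relies on AVX2 for its int8 kernels:

import torch

# FBGEMM is the x86 backend for quantized ops; it needs AVX2 (and uses
# AVX-512 where available) for fast int8 matrix multiplications.
print(torch.backends.quantized.engine)             # e.g. 'fbgemm'
print(torch.backends.quantized.supported_engines)  # e.g. ['none', 'fbgemm', 'qnnpack']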

I am running on CentOS Linux release 7.9.2009 with an Intel(R) Xeon(R) CPU E5-2690 v3.
I am running transformers v4.3.2 with torch v1.7.0. I installed the CUDA 11.0 version of PyTorch, but I am not moving the tensors or the model to the GPU in the example provided.

I ran the same example on a notebook with an i5-8350U CPU and could observe a significant speed-up. Could this be linked to an issue with Haswell CPUs?

Detailed configuration:

ATen/Parallel:
	at::get_num_threads() : 6
	at::get_num_interop_threads() : 3
OpenMP 201511 (a.k.a. OpenMP 4.5)
	omp_get_max_threads() : 6
Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
	mkl_get_max_threads() : 6
Intel(R) MKL-DNN v1.6.0 (Git Hash 5ef631a030a6f73131c77892041042805a06064f)
std::thread::hardware_concurrency() : 6
Environment variables:
	OMP_NUM_THREADS : [not set]
	MKL_NUM_THREADS : [not set]
ATen parallel backend: OpenMP
PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v1.6.0 (Git Hash 5ef631a030a6f73131c77892041042805a06064f)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.0
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80
  - CuDNN 8.0.4
  - Magma 2.5.2
  - Build settings: BLAS=MKL, BUILD_TYPE=Release, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DUSE_VULKAN_WRAPPER -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, USE_CUDA=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, 

Hi, have you solved this problem? I'd like to accelerate a model with static quantization on a Xeon E5-2680 v4 but I'm seeing the same thing. Is it really a problem of CPU microarchitecture?