Matmul from oneDNN is 10x slower than MKL BLAS sgemm?

Summary

I downloaded the pre-trained LLM DeepSeek-R1-1.5B and ran the same inference separately through oneDNN and through MKL BLAS. The oneDNN run comes out roughly 10x slower:

[screenshot: timing comparison of the two runs]

Here is the relevant part of the profiling script:

import torch
from torch.utils import mkldnn as mkldnn_utils

model, tokenizer = load_model_for_profiling(
    args.model,
    torch_dtype,
    random_weights=args.random_weights
)

# Apply MKLDNN/oneDNN optimization
if args.use_mkldnn:
    model = mkldnn_utils.to_mkldnn(model)

# Prepare batch inputs
if args.batch_prompts:
    # Use the provided list of prompts
    prompts = args.batch_prompts
    if len(prompts) < args.batch_size:
        # If there are not enough prompts, repeat the last prompt
        prompts.extend([prompts[-1]] * (args.batch_size - len(prompts)))
    prompts = prompts[:args.batch_size]
else:
    # Repeat the default prompt
    prompts = [args.prompt] * args.batch_size

print(f"Preparing batch of {args.batch_size} prompts...")
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

# Calculate number of prompt tokens
prompt_tokens = inputs["input_ids"].shape[1]
performance_data["prompt_tokens"] = prompt_tokens

inputs = {key: val.to(device) for key, val in inputs.items()}

def inference_fn():
    with torch.no_grad():
        for _ in range(args.iterations):
            if args.use_mkldnn and args.mkldnn_verbose:
                with torch.backends.mkldnn.verbose(torch.backends.mkldnn.VERBOSE_ON):
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)
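
For the MKL BLAS baseline I simply skip the to_mkldnn conversion, so aten::linear goes through the regular addmm / MKL sgemm path. MKL has a verbose mode analogous to the mkldnn one above, which can be used to confirm that the un-converted model really dispatches to sgemm (a rough sketch, not part of the script above):

import torch

# MKL BLAS baseline: leave the model un-converted and enable MKL verbose mode,
# which prints an MKL_VERBOSE line (e.g. SGEMM calls) for each BLAS invocation.
with torch.no_grad(), torch.backends.mkl.verbose(torch.backends.mkl.VERBOSE_ON):
    outputs = model(**inputs)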

I checked the onednn_verbose log and found that the cost of aten::linear is dominated by reorders:

onednn_verbose,primitive,exec,cpu,inner_product,brgemm:avx512_core,forward_training,src_f32:a:blocked:ab::f0 wei_f32:a:blocked:AB16b64a::f0 bia_f32:a:blocked:a::f0 dst_f32:a:blocked:ab::f0,attr-scratchpad:user,,mb1ic1536oc8960,0.596924
onednn_verbose,primitive,exec,cpu,reorder,simple:any,undef,src_f32::blocked:abc::f0 dst_f32::blocked:abc::f0,attr-scratchpad:user,,1x1x8960,4.28809
onednn_verbose,primitive,exec,cpu,reorder,simple:any,undef,src_f32::blocked:abc::f0 dst_f32::blocked:abc::f0,attr-scratchpad:user,,1x1x1536,0.0109863
onednn_verbose,primitive,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:ab::f0 dst_f32::blocked:AB16b64a::f0,attr-scratchpad:user,,8960x1536,4.32202
onednn_verbose,primitive,exec,cpu,inner_product,brgemm:avx512_core,forward_training,src_f32:a:blocked:ab::f0 wei_f32:a:blocked:AB16b64a::f0 bia_f32:a:blocked:a::f0 dst_f32:a:blocked:ab::f0,attr-scratchpad:user,,mb1ic1536oc8960,0.611816
onednn_verbose,primitive,exec,cpu,reorder,simple:any,undef,src_f32::blocked:abc::f0 dst_f32::blocked:abc::f0,attr-scratchpad:user,,1x1x8960,0.059082
onednn_verbose,primitive,exec,cpu,reorder,simple:any,undef,src_f32::blocked:abc::f0 dst_f32::blocked:abc::f0,attr-scratchpad:user,,1x1x8960,0.013916
onednn_verbose,primitive,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:ab::f0 dst_f32::blocked:AB16b64a::f0,attr-scratchpad:user,,1536x8960,4.17603
onednn_verbose,primitive,exec,cpu,inner_product,brgemm:avx512_core,forward_training,src_f32:a:blocked:ab::f0 wei_f32:a:blocked:AB16b64a::f0 bia_f32:a:blocked:a::f0 dst_f32:a:blocked:ab::f0,attr-scratchpad:user,,mb1ic8960oc1536,0.780029
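
The repeated jit:uni reorders of the 8960x1536 and 1536x8960 weights suggest the weights are still handed to oneDNN in plain ab layout on every call. As a sanity check, one can count how many tensors to_mkldnn actually converts into the opaque mkldnn layout (a rough sketch that only counts tensors, without inspecting specific DeepSeek layers):

import torch
from torch.utils import mkldnn as mkldnn_utils

converted = mkldnn_utils.to_mkldnn(model)

# to_mkldnn() swaps nn.Linear for MkldnnLinear, which stores its weight as an
# opaque mkldnn tensor; count how many tensors actually ended up in that layout.
tensors = list(converted.named_parameters()) + list(converted.named_buffers())
n_mkldnn = sum(1 for _, t in tensors if t.is_mkldnn)
print(f"{n_mkldnn} of {len(tensors)} tensors are in mkldnn layout")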

In the verbose trace, each inner_product call takes roughly 0.6-0.8 ms, while each jit:uni weight reorder into the blocked AB16b64a layout takes roughly 4.2-4.3 ms, so the reorders dominate. My question: is there a way to convert a pretrained model's weights, as downloaded through PyTorch, into the oneDNN blocked format once up front, so that subsequent linear operations don't have to reorder them on every call?
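
For context, this is the kind of ahead-of-time weight prepacking I have in mind. The only route I know of is ipex.optimize from intel_extension_for_pytorch, which is not part of my current setup, so I am not sure whether it actually eliminates the reorders above or whether stock PyTorch has an equivalent (a sketch, assuming IPEX is installed):

import torch
import intel_extension_for_pytorch as ipex  # assumption: IPEX is installed

model.eval()
with torch.no_grad():
    # Repack Linear/Conv2d weights into the oneDNN blocked layout once, up front,
    # rather than reordering them inside every linear call.
    model = ipex.optimize(model, dtype=torch.float32, weights_prepack=True)
    outputs = model(**inputs)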

Here is my PyTorch environment:
root@server1:/workspace# python -c "import torch; print(f'Pytorch version : {torch.__version__}'); print(*torch.__config__.show().split('\n'), sep='\n')"
Pytorch version : 2.5.1+cu124
PyTorch built with:

  • GCC 9.3
  • C++ Version: 201703
  • Intel(R) oneAPI Math Kernel Library Version 2024.2-Product Build 20240605 for Intel(R) 64 architecture applications
  • Intel(R) MKL-DNN v3.5.3 (Git Hash 66f0cb9)
  • OpenMP 201511 (a.k.a. OpenMP 4.5)
  • LAPACK is enabled (usually provided by MKL)
  • NNPACK is enabled
  • CPU capability usage: AVX512
  • Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.4, CUDNN_VERSION=9.1.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DLIBKINETO_NOXPUPTI=ON -DUSE_FBGEMM -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, TORCH_VERSION=2.5.1, USE_CUDA=ON, USE_CUDNN=ON, USE_CUSPARSELT=1, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=ON, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF,

root@server1:/workspace# lscpu
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 48
On-line CPU(s) list: 0-47
Vendor ID: GenuineIntel
BIOS Vendor ID: Intel
Model name: Intel(R) Xeon(R) Silver 4214 CPU @ 2.20GHz
BIOS Model name: Intel(R) Xeon(R) Silver 4214 CPU @ 2.20GHz
CPU family: 6
Model: 85
Thread(s) per core: 2
Core(s) per socket: 12
Socket(s): 2
Stepping: 7
BogoMIPS: 4400.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc aperfmperf eagerfpu pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch epb cat_l3 cdp_l3 invpcid_single intel_ppin intel_pt ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke avx512_vnni md_clear spec_ctrl intel_stibp flush_l1d arch_capabilities
Virtualization features:
Virtualization: VT-x
Caches (sum of all):
L1d: 768 KiB (24 instances)
L1i: 768 KiB (24 instances)
L2: 24 MiB (24 instances)
L3: 33 MiB (2 instances)
NUMA:
NUMA node(s): 2
NUMA node0 CPU(s): 0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36,38,40,42,44,46
NUMA node1 CPU(s): 1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31,33,35,37,39,41,43,45,47
Vulnerabilities:
Itlb multihit: KVM: Mitigation: Split huge pages
L1tf: Not affected
Mds: Not affected
Meltdown: Not affected
Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Spectre v1: Mitigation; Load fences, usercopy/swapgs barriers and __user pointer sanitization
Spectre v2: Mitigation; Full retpoline, IBPB
Srbds: Not affected
Tsx async abort: Mitigation; Clear CPU buffers; SMT vulnerable