How can I use the function at::cuda::blas::gemm<float>()?

I am following the official tutorial on building custom C++ and CUDA extensions, and I would like to use the function at::cuda::blas::gemm<float>(), declared in <ATen/cuda/CUDABlas.h>, to compute a matrix product. However, with the current configuration g++ does not seem to link this function. Could anyone give me some help?

Steps to reproduce the behavior with a toy example:

  1. The cpp file pytorch_gemm_gpu.cpp:

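     #include <torch/extension.h>
     #include <ATen/cuda/CUDABlas.h>

     // Square N x N float32 CUDA matrices only; no shape/device checks.
     torch::Tensor gemm(torch::Tensor A, torch::Tensor B) {
       int64_t N = A.size(0);
       torch::Tensor C = torch::zeros_like(A);
       // cuBLAS-style gemm is column-major; passing B first yields the
       // row-major product C = A * B.
       at::cuda::blas::gemm<float>('n', 'n',                       // transa, transb
                                    N, N, N,                       // M, N, K
                                    1.0f, B.data_ptr<float>(), N,  // alpha, a, lda
                                    A.data_ptr<float>(), N, 0.0f,  // b, ldb, beta
                                    C.data_ptr<float>(), N);       // c, ldc
       return C;
     }

     PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
       m.def("gemm", &gemm, "gemm");
     }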

    The setup file setup.py:

     from setuptools import setup
     from torch.utils.cpp_extension import BuildExtension, CUDAExtension

     setup(name='pytorch_gemm_gpu',
           ext_modules=[CUDAExtension('pytorch_gemm_gpu', ['pytorch_gemm_gpu.cpp'])],
           cmdclass={'build_ext': BuildExtension})
  2. Run the following command to install the package. This procedure finished without errors.

    python setup.py install
  3. Run the following command to import the package.

    python -c "import torch; import pytorch_gemm_gpu"

    The import fails, and the error message reports the symbol: _ZN2at4cuda4blas4gemmIfEEvcclllT_PKS3_lS5_lS3_PS3_l
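The mangled name in the error message can be decoded to check which declaration the loader is looking for (`c++filt` does the same on the command line). A minimal C++ sketch using the demangler that ships with the GCC/Clang C++ runtime, `abi::__cxa_demangle` from `<cxxabi.h>`; the helper name `demangle` is hypothetical:

```cpp
#include <cxxabi.h>
#include <cstdlib>
#include <string>

// Hypothetical helper: decode an Itanium-ABI mangled name (the scheme used
// by GCC and Clang on Linux). Returns the input unchanged on failure.
std::string demangle(const char* mangled) {
  int status = 0;
  // __cxa_demangle returns a malloc'ed buffer (or nullptr on failure).
  char* out = abi::__cxa_demangle(mangled, nullptr, nullptr, &status);
  std::string result = (status == 0 && out != nullptr) ? out : mangled;
  std::free(out);  // free(nullptr) is a no-op
  return result;
}
```

Applied to the symbol above, it yields `void at::cuda::blas::gemm<float>(char, char, long, long, long, float, float const*, long, float const*, long, float, float*, long)`, which is the `float` specialization called in the toy example. That suggests the header declaration matched at compile time and the problem lies in resolving the definition when the module is loaded.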

Expected behavior

I expect the command python -c "import torch; import pytorch_gemm_gpu" to succeed without errors.
I did find the symbol _ZN2at4cuda4blas4gemmIfEEvcclllT_PKS3_lS5_lS3_PS3_l in the …, so I do not know why this symbol cannot be found at runtime.


PyTorch version: 1.6.0
Is debug build: False
CUDA used to build PyTorch: 10.1

OS: Ubuntu 16.04.6 LTS (x86_64)
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.12) 5.4.0 20160609
Clang version: Could not collect
CMake version: version 3.14.5

Python version: 3.7 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.1.168
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti

Nvidia driver version: 418.67
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] numpy==1.19.1
[pip3] pytorch-gemm-gpu==0.0.0
[pip3] torch==1.6.0
[pip3] torchvision==0.7.0
[conda] blas                      1.0                         mkl
[conda] cudatoolkit               10.1.243             h6bb024c_0
[conda] mkl                       2020.1                      217
[conda] mkl-service               2.3.0            py37he904b0f_0
[conda] mkl_fft                   1.1.0            py37h23d657b_0
[conda] mkl_random                1.1.1            py37h0da4684_0    conda-forge
[conda] numpy                     1.19.1           py37hbc911f0_0
[conda] numpy-base                1.19.1           py37hfa32c7d_0
[conda] pytorch                   1.6.0           py3.7_cuda10.1.243_cudnn7.6.3_0    pytorch
[conda] pytorch-gemm-gpu          0.0.0                    pypi_0    pypi
[conda] torchvision               0.7.0                py37_cu101    pytorch


In extensions, you should not try to access low-level functions like this one directly.
You can call at::mm() or at::bmm() with Tensor arguments to get the same result. These functions also come with full autograd support, which the low-level ones lack.
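If a BLAS-style gemm is called directly instead of `at::mm()`, operand order matters. Assuming `at::cuda::blas::gemm` follows the column-major convention of cuBLAS (as BLAS `gemm` routines do), a contiguous row-major PyTorch tensor is seen as its transpose. So `gemm('n', 'n', ...)` with row-major `A` passed as `a` and `B` as `b` produces B·A, while swapping the operands produces A·B, the same result as `at::mm(A, B)`. The CPU sketch below checks this with a naive column-major reference; `ref_gemm` and `rowmajor_mm` are hypothetical names, not PyTorch functions.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference (not a PyTorch API) using the BLAS convention:
// matrices are column-major, so element (i, j) of a matrix with leading
// dimension ld is stored at ptr[i + j * ld].
// Computes C = alpha * op(A) * op(B) + beta * C, op(X) = X or X^T ('n'/'t').
void ref_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
              float alpha, const float* a, int64_t lda,
              const float* b, int64_t ldb,
              float beta, float* c, int64_t ldc) {
  for (int64_t j = 0; j < n; ++j) {
    for (int64_t i = 0; i < m; ++i) {
      float acc = 0.0f;
      for (int64_t p = 0; p < k; ++p) {
        // op(A)(i, p) and op(B)(p, j) read from column-major storage.
        float av = (transa == 'n') ? a[i + p * lda] : a[p + i * lda];
        float bv = (transb == 'n') ? b[p + j * ldb] : b[j + p * ldb];
        acc += av * bv;
      }
      c[i + j * ldc] = alpha * acc + beta * c[i + j * ldc];
    }
  }
}

// Hypothetical helper: row-major C = A * B for N x N matrices (PyTorch's
// layout). A row-major buffer read as column-major is the transpose, and
// B^T * A^T = (A * B)^T, whose column-major storage is A * B in row-major.
std::vector<float> rowmajor_mm(const std::vector<float>& A,
                               const std::vector<float>& B, int64_t N) {
  std::vector<float> C(static_cast<std::size_t>(N * N), 0.0f);
  ref_gemm('n', 'n', N, N, N, 1.0f, B.data(), N, A.data(), N, 0.0f,
           C.data(), N);
  return C;
}
```

With A = [[1, 2], [3, 4]] and B = [[5, 6], [7, 8]] stored row-major, `rowmajor_mm(A, B, 2)` gives {19, 22, 43, 50} (A·B), whereas calling `ref_gemm` with A as `a` and B as `b` gives {23, 34, 31, 46} (B·A). The reference loop is only meant for checking argument conventions such as trans flags and leading dimensions; it is not a substitute for the GPU call.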

Thank you so much for your reply!

Hello, for my use case it is absolutely necessary to use this "low-level" function at::cuda::blas::gemm<scalar_t>, and I am getting the same linker error as the OP. Alternatively, I suppose I could link directly against cuBLAS and call the original function from there, but it would be nicer to use this wrapper, since it elegantly handles different floating-point types.

I have made sure that PyTorch and the extension are built with the same compiler and the same ABI flag, since a mismatch there seems to be another common source of this linker error.
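One way to double-check the ABI flag on the extension side is to compile a tiny check with the same compiler and flags, then compare the result with `torch._C._GLIBCXX_USE_CXX11_ABI` in Python. A minimal sketch; `cxx11_abi_flag` is a hypothetical helper name:

```cpp
#include <iostream>  // any libstdc++ header defines _GLIBCXX_USE_CXX11_ABI

// Hypothetical helper: report the libstdc++ dual-ABI setting this
// translation unit was compiled with.
// 1 = new (C++11) ABI, 0 = old ABI, -1 = macro not defined (not libstdc++).
int cxx11_abi_flag() {
#ifdef _GLIBCXX_USE_CXX11_ABI
  return _GLIBCXX_USE_CXX11_ABI;
#else
  return -1;
#endif
}
```

Note that the dual ABI only changes the mangling of types such as `std::string` and `std::list`. The missing `gemm<float>` symbol involves only built-in types (`char`, `long`, `float`) and pointers to them, so an ABI-flag mismatch would not by itself explain this particular undefined symbol.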