I am following the official tutorial to build a custom CUDA extension, and I would like to use `at::cuda::blas::gemm<float>()`, declared in `ATen/cuda/CUDABlas.h`, to compute a matrix product. With my current configuration the extension builds, but this function does not seem to get linked. Could anyone give me some help?
Steps to reproduce the behavior with a toy example:
The cpp file (`pytorch_gemm_gpu.cpp`):

```cpp
#include <ATen/cuda/CUDABlas.h>
#include <torch/extension.h>

torch::Tensor gemm(torch::Tensor A, torch::Tensor B) {
  int64_t N = A.size(0);
  torch::Tensor C = torch::zeros_like(A);
  at::cuda::blas::gemm<float>('n', 'n',                      // transa, transb
                              N, N, N,                       // M, N, K
                              1.0f, A.data_ptr<float>(), N,  // alpha, a, lda
                              B.data_ptr<float>(), N, 0.0f,  // b, ldb, beta
                              C.data_ptr<float>(), N);       // c, ldc
  return C;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("gemm", &gemm, "gemm");
}
```
The setup file:

```python
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='pytorch_gemm_gpu',
    ext_modules=[
        CUDAExtension('pytorch_gemm_gpu', ['pytorch_gemm_gpu.cpp']),
    ],
    cmdclass={
        'build_ext': BuildExtension
    })
```
Run the following command to install the package. This procedure finished without errors.
python setup.py install
Run the following command to import the package.
python -c "import torch; import pytorch_gemm_gpu"
The error message is:
pytorch_gemm_gpu.cpython-37m-x86_64-linux-gnu.so:undefined symbol: _ZN2at4cuda4blas4gemmIfEEvcclllT_PKS3_lS5_lS3_PS3_l
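To make the mangled name in the error easier to read, it can be passed through a demangler. The following is a minimal sketch, assuming the `c++filt` tool from binutils is on `PATH`; the helper name `demangle` is hypothetical and returns `None` when the tool is not installed.

```python
import shutil
import subprocess


def demangle(symbol):
    """Return the demangled form of a C++ symbol, or None if c++filt is unavailable."""
    tool = shutil.which("c++filt")
    if tool is None:
        return None
    # c++filt prints the demangled name of each symbol given on the command line.
    result = subprocess.run([tool, symbol], capture_output=True, text=True, check=True)
    return result.stdout.strip()


# The symbol reported as undefined in the import error above.
missing = "_ZN2at4cuda4blas4gemmIfEEvcclllT_PKS3_lS5_lS3_PS3_l"
print(demangle(missing))
```

The output should name the `at::cuda::blas::gemm<float>` specialization together with its full parameter list, which confirms that the unresolved symbol is exactly the function called in the cpp file.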
Expected behavior
I expect the command `python -c "import torch; import pytorch_gemm_gpu"` to succeed without errors.
I did find the symbol `_ZN2at4cuda4blas4gemmIfEEvcclllT_PKS3_lS5_lS3_PS3_l` in `libtorch_cuda.so`, so I do not know why it cannot be found at runtime.
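One way to narrow this down is to ask the dynamic loader itself whether it can resolve the symbol, since a symbol can be present in a library's full symbol table without being exported for dynamic linking. The sketch below uses `ctypes`, which resolves names through the same dynamic lookup used at import time; the helper name `exports_symbol` is hypothetical, and the `torch/lib/libtorch_cuda.so` location in the commented usage is an assumption about the install layout.

```python
import ctypes


def exports_symbol(library_path, symbol):
    """Return True if the dynamic loader can resolve `symbol` in the given library."""
    lib = ctypes.CDLL(library_path)
    try:
        # Indexing a CDLL performs a dynamic symbol lookup and raises
        # AttributeError when the name is not exported.
        lib[symbol]
    except AttributeError:
        return False
    return True


# Usage (assumed path; importing torch first loads the library's dependencies):
#   import os, torch
#   path = os.path.join(os.path.dirname(torch.__file__), "lib", "libtorch_cuda.so")
#   print(exports_symbol(path, "_ZN2at4cuda4blas4gemmIfEEvcclllT_PKS3_lS5_lS3_PS3_l"))
```

If this returns `False` even though the name shows up when listing the library's symbols, the symbol is probably not exported dynamically, which would be consistent with the import error. If it returns `True`, the problem is more likely in how the extension is linked or in which libraries are loaded before it.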
PyTorch version: 1.6.0
Is debug build: False
CUDA used to build PyTorch: 10.1
OS: Ubuntu 16.04.6 LTS (x86_64)
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.12) 5.4.0 20160609
Clang version: Could not collect
CMake version: version 3.14.5
Python version: 3.7 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.1.168
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti
Nvidia driver version: 418.67
cuDNN version: Could not collect
Versions of relevant libraries:
[pip3] numpy==1.19.1
[pip3] pytorch-gemm-gpu==0.0.0
[pip3] torch==1.6.0
[pip3] torchvision==0.7.0
[conda] blas 1.0 mkl
[conda] cudatoolkit 10.1.243 h6bb024c_0
[conda] mkl 2020.1 217
[conda] mkl-service 2.3.0 py37he904b0f_0
[conda] mkl_fft 1.1.0 py37h23d657b_0
[conda] mkl_random 1.1.1 py37h0da4684_0 conda-forge
[conda] numpy 1.19.1 py37hbc911f0_0
[conda] numpy-base 1.19.1 py37hfa32c7d_0
[conda] pytorch 1.6.0 py3.7_cuda10.1.243_cudnn7.6.3_0 pytorch
[conda] pytorch-gemm-gpu 0.0.0 pypi_0 pypi
[conda] torchvision 0.7.0 py37_cu101 pytorch