Hi,
I want to include some C++/CUDA extensions in a project. However, when I run the project, I get the following error:
ModuleNotFoundError: No module named 'HT.transformers.causal_product.causal_product_cpu'
My C++ files are located in the folder `HT/transformers/causal_product` and are named `causal_product_cpu.cpp` and `causal_product_cuda.cu`. Those files come from another repository, so I assume that my problem doesn't come from them.
I built the extensions by running `python setup.py build_ext`. Here is the `setup.py` file:
```python
from functools import lru_cache
from setuptools import find_packages, setup
from subprocess import DEVNULL, call
from torch.utils.cpp_extension import BuildExtension, CppExtension


@lru_cache(None)
def cuda_toolkit_available():
    try:
        call(["nvcc"], stdout=DEVNULL, stderr=DEVNULL)
        return True
    except FileNotFoundError:
        return False


def get_extensions():
    extensions = [
        CppExtension(
            "HT.transformers.causal_product.causal_product_cpu",
            sources=[
                "HT/transformers/causal_product/causal_product_cpu.cpp"
            ],
            extra_compile_args=["-fopenmp", "-ffast-math"]
        )
    ]
    if cuda_toolkit_available():
        from torch.utils.cpp_extension import CUDAExtension
        extensions += [
            CUDAExtension(
                "HT.transformers.causal_product.causal_product_cuda",
                sources=[
                    "HT/transformers/causal_product/causal_product_cuda.cu"
                ],
                extra_compile_args=["-arch=compute_50"]
            )
        ]
    print(extensions)
    return extensions


if __name__ == '__main__':
    setup(...,
          license='',
          packages=find_packages(exclude=['models']),
          zip_safe=False,
          ext_modules=get_extensions(),
          cmdclass={'build_ext': BuildExtension},
          install_requires=['torch'])
```
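For reference, the dotted name passed to `CppExtension`/`CUDAExtension` determines both the import path and the file name of the compiled module: each dot becomes a directory, and the last component gets the interpreter's extension suffix. The small helper below (`expected_extension_path` is a hypothetical name, not part of setuptools or PyTorch) sketches that mapping with `sysconfig`, assuming paths relative to the project root:

```python
import sysconfig
from pathlib import PurePosixPath


def expected_extension_path(ext_name):
    """Map a dotted extension name to the relative path of its compiled file.

    Hypothetical helper: mirrors how built extensions are named
    (package directories + module name + interpreter-specific suffix).
    """
    # e.g. ".cpython-36m-x86_64-linux-gnu.so" on CPython 3.6 / Linux
    suffix = sysconfig.get_config_var("EXT_SUFFIX")
    *packages, module = ext_name.split(".")
    return str(PurePosixPath(*packages, module + suffix))


if __name__ == "__main__":
    print(expected_extension_path("HT.transformers.causal_product.causal_product_cpu"))
```

For the CPU extension this prints something like `HT/transformers/causal_product/causal_product_cpu.cpython-36m-x86_64-linux-gnu.so`. Python can only import `HT.transformers.causal_product.causal_product_cpu` if a file with that name sits under a directory on `sys.path`, with `HT`, `transformers` and `causal_product` importable as packages.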
After I build the extensions, a `build/` directory is created. In particular, the subfolder `build/lib.linux-x86_64-3.6/HT/transformers/causal_product/` contains some `.so` files and the subfolder `build/temp.linux-x86_64-3.6/HT/transformers/causal_product/` contains some `.o` files, so I assume that they have been compiled correctly.
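One way to narrow this down is to compare where the `.so` files ended up with what the interpreter can actually import. By default `build_ext` writes its output to `build/lib.*/` rather than next to the sources, and that directory is not on `sys.path` unless something adds it. The script below is a diagnostic sketch (the helper names are hypothetical); it assumes it is run from the project root, the directory containing `setup.py`. Note that `find_spec` on a dotted name imports the parent packages, so `HT/__init__.py` will run.

```python
import importlib.util
from pathlib import Path


def find_built_extensions(build_dir="build"):
    """List compiled extension files (.so/.pyd) under build/lib.*/ (hypothetical helper)."""
    root = Path(build_dir)
    return sorted(
        str(path)
        for lib_dir in root.glob("lib.*")
        for path in lib_dir.rglob("*")
        if path.is_file() and path.suffix in {".so", ".pyd"}
    )


def is_importable(module_name):
    """Return True if `module_name` can be located on the current sys.path."""
    try:
        return importlib.util.find_spec(module_name) is not None
    except ModuleNotFoundError:
        # Raised when a parent package (e.g. "HT") cannot be found,
        # or when a parent exists but is a plain module, not a package.
        return False


if __name__ == "__main__":
    for path in find_built_extensions():
        print("built:", path)
    for name in (
        "HT.transformers.causal_product.causal_product_cpu",
        "HT.transformers.causal_product.causal_product_cuda",
    ):
        print(name, "importable:", is_importable(name))
```

If the `built:` lines appear but the modules report `importable: False`, the compiled files exist but are not in a location the interpreter searches.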
This is the first time I have tried to include C++ code in a PyTorch project, and I don't know anything about C++ compilation, so I might have missed something obvious…
Does anyone have an idea how to solve my problem?
Thanks in advance!
Alain