Dynamic Parallelism Extension Error

Hello, I have been trying to build a CUDA extension in which a parent kernel launches a child kernel, which CUDA dynamic parallelism is meant to support.

I have a dyn_par_cuda_kernel.cu file:

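#include <torch/types.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdio>  // printf

namespace {
    // Child kernel: launched from the device by parent_kernel.
    __global__ void child_kernel(){
        printf("Running child kernel \n" );
    }
    // Parent kernel: the nested launch below is what needs -rdc=true and cudadevrt.
    __global__ void parent_kernel() {
        child_kernel<<<1, 1>>>();
    }
}

int test_dynamic_paralellism_cuda(){
    parent_kernel<<<1, 1>>>();
    // Wait for both kernels so the device-side printf output is flushed;
    // returns 0 (cudaSuccess) when they ran without error.
    return static_cast<int>(cudaDeviceSynchronize());
}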

A dyn_par_cuda.cpp file:

#include <torch/extension.h>

int test_dynamic_paralellism_cuda(void);

int test_dynamic_paralellism(void){
    return test_dynamic_paralellism_cuda();
}

// Binder
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("test_dynamic_paralellism", &test_dynamic_paralellism, "test_dynamic_paralellism");
}

And finally a setup.py file:

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='dyn_par_test',
    ext_modules=[
        CUDAExtension('dyn_par_test', [
            'dyn_par_cuda.cpp',
            'dyn_par_cuda_kernel.cu',
        ],
        extra_compile_args={'cxx': ['-Wall'], 'nvcc': ['-rdc=true', '-lcudadevrt']}),
    ],
    cmdclass={
        'build_ext': BuildExtension
    })

I run python setup.py install and get the following error:

running install
running bdist_egg
running egg_info
writing dyn_par_test.egg-info\PKG-INFO
writing dependency_links to dyn_par_test.egg-info\dependency_links.txt
writing top-level names to dyn_par_test.egg-info\top_level.txt
reading manifest file 'dyn_par_test.egg-info\SOURCES.txt'
writing manifest file 'dyn_par_test.egg-info\SOURCES.txt'
installing library code to build\bdist.win-amd64\egg
running install_lib
running build_ext
C:\Users\my_user\Anaconda3\envs\pytorch\lib\site-packages\torch\utils\cpp_extension.py:274: UserWarning: Error checking compiler version for cl: [WinError 2] The system cannot find the file specified
  warnings.warn('Error checking compiler version for {}: {}'.format(compiler, error))
building 'dyn_par_test' extension
Emitting ninja build file C:\Users\my_user\PycharmProjects\torch_extensions\my_cuda\build\temp.win-amd64-3.7\Release\build.ninja...
Compiling objects...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.28.29333\bin\HostX86\x64\link.exe /nologo /INCREMENTAL:NO /LTCG /DLL /MANIFEST:EMBED,ID=2 /MANIFESTUAC:NO /LIBPATH:C:\Users\my_user\Anaconda3\envs\pytorch\lib\site-packages\torch\lib "/LIBPATH:C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.0\lib/x64" /LIBPATH:C:\Users\my_user\Anaconda3\envs\pytorch\libs /LIBPATH:C:\Users\my_user\Anaconda3\envs\pytorch\PCbuild\amd64 "/LIBPATH:C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.28.29333\ATLMFC\lib\x64" "/LIBPATH:C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.28.29333\lib\x64" "/LIBPATH:C:\Program Files (x86)\Windows Kits\NETFXSDK\4.6.1\lib\um\x64" "/LIBPATH:C:\Program Files (x86)\Windows Kits\10\lib\10.0.18362.0\ucrt\x64" "/LIBPATH:C:\Program Files (x86)\Windows Kits\10\lib\10.0.18362.0\um\x64" c10.lib torch.lib torch_cpu.lib torch_python.lib cudart.lib c10_cuda.lib torch_cuda.lib /EXPORT:PyInit_dyn_par_test C:\Users\my_user\PycharmProjects\torch_extensions\my_cuda\build\temp.win-amd64-3.7\Release\dyn_par_cuda.obj C:\Users\my_user\PycharmProjects\torch_extensions\my_cuda\build\temp.win-amd64-3.7\Release\dyn_par_cuda_kernel.obj /OUT:build\lib.win-amd64-3.7\dyn_par_test.cp37-win_amd64.pyd /IMPLIB:C:\Users\my_user\PycharmProjects\torch_extensions\my_cuda\build\temp.win-amd64-3.7\Release\dyn_par_test.cp37-win_amd64.lib
   Creating library C:\Users\my_user\PycharmProjects\torch_extensions\my_cuda\build\temp.win-amd64-3.7\Release\dyn_par_test.cp37-win_amd64.lib and object C:\Users\my_user\PycharmProjects\torch_extensions\my_cuda\build\temp.win-amd64-3.7\Release\dyn_par_test.cp37-win_amd64.exp
dyn_par_cuda_kernel.obj : error LNK2001: unresolved external symbol __cudaRegisterLinkedBinary_55_tmpxft_000019e0_00000000_11_dyn_par_cuda_kernel_cpp1_ii_21ca234b
build\lib.win-amd64-3.7\dyn_par_test.cp37-win_amd64.pyd : fatal error LNK1120: 1 unresolved externals
error: command 'C:\\Program Files (x86)\\Microsoft Visual Studio\\2019\\Community\\VC\\Tools\\MSVC\\14.28.29333\\bin\\HostX86\\x64\\link.exe' failed with exit status 1120

Similar errors were reported in two earlier issues, but no solution has been posted so far.

I have tried variations of extra_compile_args={'cxx': ['-Wall'], 'nvcc': ['-rdc=true', '-lcudadevrt']}, such as adding -arch=sm_61, with no improvement. The error persists even if I remove the child kernel launch from the parent kernel, which suggests that the problem comes from the '-rdc=true' and '-lcudadevrt' arguments themselves (i.e. from the separate compilation and linking that dynamic parallelism requires) rather than from the nested launch.
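For reference, an unresolved __cudaRegisterLinkedBinary_... symbol is usually what a missing device-link step looks like: -rdc=true produces relocatable device code, but nothing in the build above runs nvcc -dlink before link.exe, and -lcudadevrt is a linker option, so it has no effect among the compile flags. Below is a minimal sketch of how the extension arguments might look, assuming a PyTorch version whose CUDAExtension accepts the dlink and dlink_libraries keywords (I believe newer releases do, and that they require Ninja). The helper names extension_kwargs and main are hypothetical, and linking cudadevrt in both the device-link and the host-link step is my assumption. It has not been tested on the setup described here.

```python
# Sketch of a setup.py for relocatable device code (RDC).
# Assumption: this PyTorch version's CUDAExtension accepts `dlink` and
# `dlink_libraries`, which add an `nvcc -dlink` step to the build.


def extension_kwargs():
    """Keyword arguments for CUDAExtension (hypothetical helper)."""
    return dict(
        name='dyn_par_test',
        sources=['dyn_par_cuda.cpp', 'dyn_par_cuda_kernel.cu'],
        dlink=True,                     # request the device-link step
        dlink_libraries=['cudadevrt'],  # device runtime for the device link
        libraries=['cudadevrt'],        # assumption: also needed at host link
        extra_compile_args={
            'cxx': ['-Wall'],
            # Only the compile-time flag stays here; -lcudadevrt is a
            # linker option and is handled by the entries above instead.
            'nvcc': ['-rdc=true'],
        },
    )


def main():
    """Run the build (hypothetical helper); imports are local so that
    extension_kwargs() can be inspected without PyTorch installed."""
    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    setup(
        name='dyn_par_test',
        ext_modules=[CUDAExtension(**extension_kwargs())],
        cmdclass={'build_ext': BuildExtension},
    )
```

To use this as setup.py, main() would be called at the bottom of the file. If the installed PyTorch does not know the dlink keyword, this sketch does not apply as written.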

Thank you for your help!
