Custom CUDA Extension with Dynamic Parallelism

Hi,

I am trying to create a custom module using CUDA with Dynamic Parallelism. I have followed the basic extension example from the PyTorch documentation.

My newmodule.cpp is as follows

#include <torch/extension.h>

cudaError_t newmodule_cuda(...);

torch::Tensor newmodule(
    const torch::Tensor& self,
    const torch::Tensor& weight,
    const torch::Tensor& bias,
    int64_t pad)
{
	// do some operations
        newmodule_cuda(...);
        // do some operations
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &newmodule, "newmodule CUDA");
}

My newmodule_cuda.cu is as follows.

#include <cuda.h>
#include <cuda_runtime.h>
#include <driver_functions.h>
#include <torch/extension.h>
#include <ATen/ATen.h>

template <typename scalar_t>
__global__ 
void child(.....)
{
	// do operation with tensor
}

template <typename scalar_t>
__global__
void parent(...)
{
	// do parent operation
	
	// launching child
		
	// Convolve
	child<<<numBlock, numThread>>>(...);
	child<<<numBlock, numThread>>>(...);
	// do parent operation	
}

cudaError_t newmodule_cuda(...)
{
	torch::Device deviceCPU(torch::kCPU);
    torch::Device deviceGPU(torch::kCPU);
	if (torch::cuda::is_available())
    {
        std::cout << "CUDA is available! Run on GPU." << std::endl;
		// do preparation
	    AT_DISPATCH_FLOATING_TYPES(self.type(), "newmodule_cuda", ([&] {
			parent<scalar_t><<<numBlock,numThread>>>(...);	
		}));
        cudaDeviceSynchronize();
    }
	else
	{
		// CPU code
		
	}
	return cudaSuccess;
}

My setup.py is as follows.

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 
setup(
    name='newmodule',
    ext_modules=[
        CUDAExtension('convtbcglu_cuda', [
            'newmodule.cpp',
            'newmodule_cuda.cu',
        ],
		extra_compile_args={'cxx': ['-Wall'], 'nvcc': ['-arch=sm_70']})                  
    ],
    cmdclass={
        'build_ext': BuildExtension
    })

With this setup, I get an error when executing python setup.py install

error: kernel launch from __device__ or __global__ functions requires separate compilation mode

Then I added '-rdc=true' and '-lcudadevrt' to my setup.py as follows.

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 
setup(
    name='newmodule',
    ext_modules=[
        CUDAExtension('convtbcglu_cuda', [
            'newmodule.cpp',
            'newmodule_cuda.cu',
        ],
		extra_compile_args={'cxx': ['-Wall'], 'nvcc': ['-arch=sm_70', '-rdc=true', '-lcudadevrt']})                   
    ],
    cmdclass={
        'build_ext': BuildExtension
    })

It compiles successfully, but when I try to run newmodule, it returns an error.

undefined symbol: __cudaRegisterLinkedBinary_50

I have read this Issue but still cannot figure out the solution.

Any help would be appreciated.

Thank you.


Try changing

__global__
void child(…)

to

__device__
void child(…)

Let me know if that helps.

Reproduced the error with:

Python: 3.7.5
PyTorch: 1.3.1
nvcc: 10.0.130
c++ is gcc 7.4.0
os: Ubuntu

I have no idea what is causing this.
Dynamic parallelism is vital for my current task, will be following this post closely.

@hibagus: Did you try to compile things by hand?

Reproduced error also with pytorch 1.2.0.

I discovered a version mismatch:

nvcc -V -> 10.0.130

import torch
torch.version.cuda # gives 10.1.243

Downgrading to pytorch 1.2.0 made the two versions match, but it did not resolve the issue.
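
The comparison above can be scripted. This is a minimal sketch, assuming the nvcc version output contains a field of the form "release 10.0" (as in "Cuda compilation tools, release 10.0, V10.0.130"); the function names nvcc_release and same_cuda_release are hypothetical helpers, not part of PyTorch or CUDA.

```python
import re


def nvcc_release(nvcc_version_output):
    """Extract (major, minor) from the text printed by nvcc's version flag."""
    # Assumption: the output contains a field such as "release 10.0".
    match = re.search(r"release (\d+)\.(\d+)", nvcc_version_output)
    if match is None:
        raise ValueError("no 'release X.Y' field found in nvcc output")
    return int(match.group(1)), int(match.group(2))


def same_cuda_release(nvcc_version_output, torch_cuda_version):
    """Compare nvcc's release with a string like torch.version.cuda."""
    # Only major.minor is compared; the patch component is ignored.
    major, minor = torch_cuda_version.split(".")[:2]
    return nvcc_release(nvcc_version_output) == (int(major), int(minor))


if __name__ == "__main__":
    sample = "Cuda compilation tools, release 10.0, V10.0.130"
    print(same_cuda_release(sample, "10.1.243"))
```

In practice the first argument would come from running nvcc and the second from torch.version.cuda. As noted above, matching the two versions did not by itself resolve the undefined symbol error.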

Changing child to a device function caused the following compilation error:

error: a device function call cannot be configured

Is there any reason you expected this to work?

I got the same issue here, have you solved it?

I am also facing the exact same issue. According to this Stack Overflow question, we need an intermediate device-linking step. How can this be achieved here?
@albanD @ptrblck any suggestions on how to build with dynamic parallelism?

Currently, the PyTorch build script doesn't support it yet. You need to add an explicit device-link step with the -dlink option when you are using '-rdc=true'. For example:

nvcc -arch=sm_35 -dc a.cu b.cu
nvcc -arch=sm_35 -dlink a.o b.o -o dlink.o
g++ a.o b.o dlink.o x.cpp -lcudart

I have changed my build script and it works well.
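
The three commands in the reply above can be generated from a small build script. This is a minimal sketch, not the poster's actual script: device_link_commands is a hypothetical helper that only builds the argument lists (it does not run them), and the device_libs parameter is an assumption for passing extra libraries such as the cudadevrt library mentioned in the original post. Like the commands it mirrors, it describes a standalone build rather than a Python extension build.

```python
from pathlib import PurePath


def device_link_commands(cu_sources, cpp_sources, arch="sm_35", device_libs=()):
    """Return the compile, device-link and host-link commands as argument lists."""
    # nvcc -dc writes one object file per .cu source into the working directory.
    objects = [PurePath(src).with_suffix(".o").name for src in cu_sources]
    # Assumption: extra libraries (e.g. "cudadevrt") are passed as -l<name>.
    lib_flags = [f"-l{lib}" for lib in device_libs]

    # Step 1: compile with relocatable device code (-dc).
    compile_cmd = ["nvcc", f"-arch={arch}", "-dc", *cu_sources]
    # Step 2: device link, resolving device symbols across the object files.
    dlink_cmd = ["nvcc", f"-arch={arch}", "-dlink", *objects,
                 "-o", "dlink.o", *lib_flags]
    # Step 3: ordinary host link that includes the device-link object.
    host_link_cmd = ["g++", *objects, "dlink.o", *cpp_sources,
                     *lib_flags, "-lcudart"]
    return [compile_cmd, dlink_cmd, host_link_cmd]


if __name__ == "__main__":
    for cmd in device_link_commands(["a.cu", "b.cu"], ["x.cpp"]):
        print(" ".join(cmd))
```

Each returned list could be passed to subprocess.run(cmd, check=True) once the source names and architecture are adapted to the project.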

@Vincent-bin can you please tell us how you managed it?
Did you change the build.py or the setup.py script?
Where exactly did you add these instructions?
