Error when building a Custom C++ and CUDA Extension

I’m trying to implement a 3D NMS function in CUDA following the example outlined in Custom C++ and CUDA Extensions — PyTorch Tutorials 2.1.1+cu121 documentation.

My CUDA and C++ experience is limited, so I hope there is a simple error somewhere that one of you can spot. My code consists of three files:

  1. nms_kernel.cu, which contains the CUDA functions;
  2. nms_cuda.cpp, which contains the C++ function;
  3. setup.py.

Hereafter I’m posting the three files:

nms_kernel.cu

#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstring>
#include <vector>

#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
int const threadsPerBlock = sizeof(unsigned long long) * 8;

__device__ inline float devIoU(float const * const a, float const * const b) {
  float left = fmaxf(a[0], b[0]), right = fminf(a[2], b[2]);
  float top = fmaxf(a[1], b[1]), bottom = fminf(a[3], b[3]);
  float front = fmaxf(a[4], b[4]), back = fminf(a[5], b[5]);

  float width = fmaxf(right - left + 1, 0.f), height = fmaxf(bottom - top + 1, 0.f), depth = fmaxf(back - front + 1, 0.f);
  float interS = width * height * depth;
  float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1) * (a[5] - a[4] + 1);
  float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1) * (b[5] - b[4] + 1);
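  // IoU = intersection volume over union volume: interS / (Sa + Sb - interS)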
  //printf("IoU 3D %f \n", interS / (Sa + Sb - interS));

  return interS / (Sa + Sb - interS);
}

__global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
                           const float *dev_boxes, unsigned long long *dev_mask) {
  const int row_start = blockIdx.y;
  const int col_start = blockIdx.x;

  // if (row_start > col_start) return;

  const int row_size =
        fminf(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
  const int col_size =
        fminf(n_boxes - col_start * threadsPerBlock, threadsPerBlock);

  __shared__ float block_boxes[threadsPerBlock * 7];
  if (threadIdx.x < col_size) {
    block_boxes[threadIdx.x * 7 + 0] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 0];
    block_boxes[threadIdx.x * 7 + 1] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 1];
    block_boxes[threadIdx.x * 7 + 2] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 2];
    block_boxes[threadIdx.x * 7 + 3] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 3];
    block_boxes[threadIdx.x * 7 + 4] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 4];
    block_boxes[threadIdx.x * 7 + 5] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 5];
    block_boxes[threadIdx.x * 7 + 6] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 6];
  }
  __syncthreads();

  if (threadIdx.x < row_size) {
    const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
    const float *cur_box = dev_boxes + cur_box_idx * 7;
    int i = 0;
    unsigned long long t = 0;
    int start = 0;
    if (row_start == col_start) {
      start = threadIdx.x + 1;
    }
    for (i = start; i < col_size; i++) {
      if (devIoU(cur_box, block_boxes + i * 7) > nms_overlap_thresh) {
        t |= 1ULL << i;
      }
    }
    const int col_blocks = DIVUP(n_boxes, threadsPerBlock);
    dev_mask[cur_box_idx * col_blocks + col_start] = t;
  }
}


void _nms(int boxes_num, float * boxes_dev,
          unsigned long long * mask_dev, float nms_overlap_thresh) {


  dim3 blocks(DIVUP(boxes_num, threadsPerBlock),
              DIVUP(boxes_num, threadsPerBlock));
  dim3 threads(threadsPerBlock);
  nms_kernel<<<blocks, threads>>>(boxes_num,
                                  nms_overlap_thresh,
                                  boxes_dev,
                                  mask_dev);
}


int gpu_nms(at::Tensor keep, at::Tensor boxes, at::Tensor num_out, float nms_overlap_thresh){
    // params boxes: (N, 7) [x, y, z, dx, dy, dz, heading]
    // params keep: (N)

    int boxes_num = boxes.size(0);
    float * boxes_data = boxes.data_ptr<float>();
    long * keep_data = keep.data_ptr<long>();

    const int col_blocks = DIVUP(boxes_num, threadsPerBlock);

    unsigned long long *mask_data = NULL;
    cudaMalloc((void**)&mask_data, boxes_num * col_blocks * sizeof(unsigned long long));
    _nms(boxes_num, boxes_data, mask_data, nms_overlap_thresh);

    std::vector<unsigned long long> mask_cpu(boxes_num * col_blocks);

    //printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks);
    cudaMemcpy(&mask_cpu[0], mask_data, boxes_num * col_blocks * sizeof(unsigned long long),
                           cudaMemcpyDeviceToHost);

    cudaFree(mask_data);

    unsigned long long remv_cpu[col_blocks];
    memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long));

    int num_to_keep = 0;
    for (int i = 0; i < boxes_num; i++){
        int nblock = i / threadsPerBlock;
        int inblock = i % threadsPerBlock;

        if (!(remv_cpu[nblock] & (1ULL << inblock))){
            keep_data[num_to_keep++] = i;
            unsigned long long *p = &mask_cpu[0] + i * col_blocks;
            for (int j = nblock; j < col_blocks; j++){
                remv_cpu[j] |= p[j];
            }
        }
    }
    if ( cudaSuccess != cudaGetLastError() ) printf( "Error!\n" );

    return num_to_keep;
}

nms_cuda.cpp

#include <torch/extension.h>

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)


int gpu_nms(at::Tensor keep, at::Tensor boxes, at::Tensor num_out, float nms_overlap_thresh);

int gpu_nms_forward(at::Tensor keep, at::Tensor boxes, at::Tensor num_out, float nms_overlap_thresh){
  CHECK_INPUT(keep);
  CHECK_INPUT(boxes);
  return gpu_nms(keep, boxes, num_out, nms_overlap_thresh);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
  m.def("forward", &gpu_nms_forward, "gpu_nms ('CUDA')");
}

setup.py

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


setup(
    name='nms_3d',
    ext_modules=[
        CUDAExtension('nms_3d', [
            'nms_cuda.cpp',
            'nms_kernel.cu'
        ])],
    cmdclass={'build_ext': BuildExtension})

System information: I’m running this in a conda environment with the following relevant packages and libraries:

  1. pytorch=1.9.1 (stable);
  2. CUDA=10.2;
  3. cudatoolkit=10.2.89;
  4. cudatoolkit-dev=10.1.243;
  5. python=3.8.8;
  6. cudnn=7.6.5;
  7. PyTorch C++ API: libtorch-shared-with-deps-1.9.1+cu102
  8. Ubuntu 20.04.3 LTS

When I try to do

python setup.py install

I get the following error:

running install
running bdist_egg
running egg_info
writing nms_3d.egg-info/PKG-INFO
writing dependency_links to nms_3d.egg-info/dependency_links.txt
writing top-level names to nms_3d.egg-info/top_level.txt
reading manifest file 'nms_3d.egg-info/SOURCES.txt'
writing manifest file 'nms_3d.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_ext
building 'nms_3d' extension
Emitting ninja build file /home/michele/Documents/GitHub/MicheleDelliVeneri/Thesis/Code/nms_3D/build/temp.linux-x86_64-3.8/build.ninja...
Compiling objects...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: error: '/home/michele/Documents/GitHub/MicheleDelliVeneri/Thesis/Code/nms_3D/nms_kernel.cu', needed by '/home/michele/Documents/GitHub/MicheleDelliVeneri/Thesis/Code/nms_3D/build/temp.linux-x86_64-3.8/nms_kernel.o', missing and no known rule to make it
Traceback (most recent call last):
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1666, in _run_ninja_build
    subprocess.run(
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/subprocess.py", line 516, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "setup.py", line 5, in <module>
    setup(
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/site-packages/setuptools/__init__.py", line 153, in setup
    return distutils.core.setup(**attrs)
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/distutils/core.py", line 148, in setup
    dist.run_commands()
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/distutils/dist.py", line 966, in run_commands
    self.run_command(cmd)
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/distutils/dist.py", line 985, in run_command
    cmd_obj.run()
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/site-packages/setuptools/command/install.py", line 67, in run
    self.do_egg_install()
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/site-packages/setuptools/command/install.py", line 109, in do_egg_install
    self.run_command('bdist_egg')
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/distutils/cmd.py", line 313, in run_command
    self.distribution.run_command(command)
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/distutils/dist.py", line 985, in run_command
    cmd_obj.run()
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/site-packages/setuptools/command/bdist_egg.py", line 164, in run
    cmd = self.call_command('install_lib', warn_dir=0)
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/site-packages/setuptools/command/bdist_egg.py", line 150, in call_command
    self.run_command(cmdname)
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/distutils/cmd.py", line 313, in run_command
    self.distribution.run_command(command)
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/distutils/dist.py", line 985, in run_command
    cmd_obj.run()
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/site-packages/setuptools/command/install_lib.py", line 11, in run
    self.build()
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/distutils/command/install_lib.py", line 107, in build
    self.run_command('build_ext')
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/distutils/cmd.py", line 313, in run_command
    self.distribution.run_command(command)
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/distutils/dist.py", line 985, in run_command
    cmd_obj.run()
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/site-packages/setuptools/command/build_ext.py", line 79, in run
    _build_ext.run(self)
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/site-packages/Cython/Distutils/old_build_ext.py", line 186, in run
    _build_ext.build_ext.run(self)
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/distutils/command/build_ext.py", line 340, in run
    self.build_extensions()
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 709, in build_extensions
    build_ext.build_extensions(self)
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/site-packages/Cython/Distutils/old_build_ext.py", line 195, in build_extensions
    _build_ext.build_ext.build_extensions(self)
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/distutils/command/build_ext.py", line 449, in build_extensions
    self._build_extensions_serial()
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/distutils/command/build_ext.py", line 474, in _build_extensions_serial
    self.build_extension(ext)
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/site-packages/setuptools/command/build_ext.py", line 196, in build_extension
    _build_ext.build_extension(self, ext)
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/distutils/command/build_ext.py", line 528, in build_extension
    objects = self.compiler.compile(sources,
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 530, in unix_wrap_ninja_compile
    _write_ninja_file_and_compile_objects(
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1355, in _write_ninja_file_and_compile_objects
    _run_ninja_build(
  File "/home/michele/anaconda3/envs/with_cudatoolkit/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1682, in _run_ninja_build
    raise RuntimeError(message) from e
RuntimeError: Error compiling objects for extension

From the error message I cannot figure out what the problem might be; any help would be greatly appreciated.

Thank you,
Michele

Manually run ninja from the console on this build file:

Emitting ninja build file /home/michele/Documents/GitHub/MicheleDelliVeneri/Thesis/Code/nms_3D/build/temp.linux-x86_64-3.8/build.ninja

and you should see the C++ error(s).

First of all, thank you for the reply. I tried what you suggested by going into the mentioned folder and running the command:

ninja

but unfortunately I get the same output message:

ninja: error: '/home/michele/Documents/GitHub/MicheleDelliVeneri/Thesis/Code/nms_3D/nms_kernel.cu', needed by '/home/michele/Documents/GitHub/MicheleDelliVeneri/Thesis/Code/nms_3D/build/temp.linux-x86_64-3.8/nms_kernel.o', missing and no known rule to make it

Am I misunderstanding what you meant by manually running ninja?

If you view build.ninja, it contains the relevant compiler commands (in your case the nms_kernel.cu build is failing). If your build file doesn’t contain absolute filenames, try either running ninja from the directory with the sources (“ninja -f <full_path_to_build.ninja>”) or copying the compiler command out of the build file.

Actually, I think there is a setup.py flag for verbose output, but I can’t tell you which one specifically right now.
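For reference, if it turns out ninja is resolving the .cu path against the wrong directory, one workaround is to hand setup.py absolute source paths. A minimal sketch, assuming both sources sit next to setup.py (untested, just to illustrate the idea):

import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# resolve the sources relative to this file so ninja always sees absolute paths
this_dir = os.path.dirname(os.path.abspath(__file__))

setup(
    name='nms_3d',
    ext_modules=[
        CUDAExtension('nms_3d', [
            os.path.join(this_dir, 'nms_cuda.cpp'),
            os.path.join(this_dir, 'nms_kernel.cu'),
        ])],
    cmdclass={'build_ext': BuildExtension})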


Thank you Alex for your responses; they gave me an idea of how to debug it, and in the end I managed to do it. Since it may be instructive for someone else, here is how it went down.

Instead of going down the setup.py route, I made a CMakeLists.txt file that looks like this:

cmake_minimum_required(VERSION 3.10 FATAL_ERROR)
# here I decide the name of the library
project(nms_cmake LANGUAGES CXX CUDA)
set(CMAKE_CXX_STANDARD 14)

# I make sure to point to the correct libraries
set(CMAKE_PREFIX_PATH /home/michele/pytorch_cpp/libtorch-shared-with-deps-1.9.1+cu102/libtorch)

# I make sure that it has all the python and torch directories available
set(TORCH_INCLUDE_DIRS
   /home/michele/anaconda3/envs/with_cudatoolkit/include/
   /home/michele/anaconda3/envs/with_cudatoolkit/include/torch/csrc/api/include)

# I make sure that both Python and Torch are found, also all the py-dev stuff
find_package(Torch REQUIRED)
find_package(Python REQUIRED COMPONENTS Development)

# add all the sources to the library
# List all your code files here
add_library(nms_cmake SHARED
  main.cu
)

target_link_libraries(nms_cmake "${TORCH_LIBRARIES}")

# To really make sure that it is using the correct language and gcc to make stuff.
target_compile_options(nms_cmake PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-ccbin g++-7>)

main.cu at this stage was just a combination of the previous nms_kernel.cu and nms_cuda.cpp scripts, and it was full of unknown syntactical errors.
After creating the build folder and launching:

 cmake -DCMAKE_PREFIX_PATH=/home/michele/pytorch_cpp/libtorch-shared-with-deps-1.9.1+cu102/libtorch ..

and

make

I could see many errors in the file, and after some debugging it ended up like this:

#include <c10/cuda/CUDAException.h>
#include <torch/torch.h>
#include <torch/all.h>
#include <torch/python.h>
#include <torch/extension.h>
#include <cstring>

using namespace at;

#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
int const threadsPerBlock = sizeof(unsigned long long) * 8;


__device__ inline float devIoU(float const * const a, float const * const b) {
  float left = fmaxf(a[0], b[0]), right = fminf(a[2], b[2]);
  float top = fmaxf(a[1], b[1]), bottom = fminf(a[3], b[3]);
  float front = fmaxf(a[4], b[4]), back = fminf(a[5], b[5]);

  float width = fmaxf(right - left + 1, 0.f), height = fmaxf(bottom - top + 1, 0.f), depth = fmaxf(back - front + 1, 0.f);
  float interS = width * height * depth;
  float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1) * (a[5] - a[4] + 1);
  float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1) * (b[5] - b[4] + 1);
  //printf("IoU 3D %f \n", interS / (Sa + Sb - interS));

  return interS / (Sa + Sb - interS);
}

__global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
                           const float *dev_boxes, unsigned long long *dev_mask) {
  const int row_start = blockIdx.y;
  const int col_start = blockIdx.x;

  // if (row_start > col_start) return;

  const int row_size =
        fminf(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
  const int col_size =
        fminf(n_boxes - col_start * threadsPerBlock, threadsPerBlock);

  __shared__ float block_boxes[threadsPerBlock * 7];
  if (threadIdx.x < col_size) {
    block_boxes[threadIdx.x * 7 + 0] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 0];
    block_boxes[threadIdx.x * 7 + 1] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 1];
    block_boxes[threadIdx.x * 7 + 2] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 2];
    block_boxes[threadIdx.x * 7 + 3] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 3];
    block_boxes[threadIdx.x * 7 + 4] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 4];
    block_boxes[threadIdx.x * 7 + 5] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 5];
    block_boxes[threadIdx.x * 7 + 6] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 6];
  }
  __syncthreads();

  if (threadIdx.x < row_size) {
    const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
    const float *cur_box = dev_boxes + cur_box_idx * 7;
    int i = 0;
    unsigned long long t = 0;
    int start = 0;
    if (row_start == col_start) {
      start = threadIdx.x + 1;
    }
    for (i = start; i < col_size; i++) {
      if (devIoU(cur_box, block_boxes + i * 7) > nms_overlap_thresh) {
        t |= 1ULL << i;
      }
    }
    const int col_blocks = DIVUP(n_boxes, threadsPerBlock);
    dev_mask[cur_box_idx * col_blocks + col_start] = t;
  }
}


void _nms(int boxes_num, float * boxes_dev,
          unsigned long long * mask_dev, float nms_overlap_thresh) {


  dim3 blocks(DIVUP(boxes_num, threadsPerBlock),
              DIVUP(boxes_num, threadsPerBlock));
  dim3 threads(threadsPerBlock);
  nms_kernel<<<blocks, threads>>>(boxes_num,
                                  nms_overlap_thresh,
                                  boxes_dev,
                                  mask_dev);
}


int64_t gpu_nms(at::Tensor keep, at::Tensor boxes, at::Tensor num_out,  double nms_overlap_thresh){
    // params boxes: (N, 7) [x, y, z, dx, dy, dz, heading]
    // params keep: (N)
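    // note: devIoU above reads the first six values of each box as opposite
    // corners [x1, y1, z1, x2, y2, z2] and ignores the heading entirely, so
    // boxes in the (x, y, z, dx, dy, dz, heading) layout may need converting first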

    int boxes_num = boxes.size(0);
    float * boxes_data = boxes.data_ptr<float>();
    long * keep_data = keep.data_ptr<long>();

    const int col_blocks = DIVUP(boxes_num, threadsPerBlock);

    //unsigned long long mask_data[boxes_num * col_blocks];
    unsigned long long *mask_data = NULL;
    cudaMalloc((void**)&mask_data, boxes_num * col_blocks * sizeof(unsigned long long));
    _nms(boxes_num, boxes_data, mask_data, nms_overlap_thresh);

    //unsigned long long mask_cpu[boxes_num * col_blocks];
    //unsigned long long * mask_cpu = new unsigned long long [boxes_num * col_blocks];
    std::vector<unsigned long long> mask_cpu(boxes_num * col_blocks);

    //printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks);
    cudaMemcpy(&mask_cpu[0], mask_data, boxes_num * col_blocks * sizeof(unsigned long long),
                           cudaMemcpyDeviceToHost);

    cudaFree(mask_data);

    unsigned long long remv_cpu[col_blocks];
    memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long));

    int num_to_keep = 0;
    for (int i = 0; i < boxes_num; i++){
        int nblock = i / threadsPerBlock;
        int inblock = i % threadsPerBlock;

        if (!(remv_cpu[nblock] & (1ULL << inblock))){
            keep_data[num_to_keep++] = i;
            unsigned long long *p = &mask_cpu[0] + i * col_blocks;
            for (int j = nblock; j < col_blocks; j++){
                remv_cpu[j] |= p[j];
            }
        }
    }
    if ( cudaSuccess != cudaGetLastError() ) printf( "Error!\n" );

    return num_to_keep;
}

TORCH_LIBRARY(pytorch_cmake_example, m) {
  m.def("gpu_nms(Tensor keep, Tensor boxes, Tensor num_out, float nms_overlap_thresh) -> int num_to_keep");
  m.impl("gpu_nms", c10::DispatchKey::CUDA, TORCH_FN(gpu_nms));
  //c10::DispatchKey::CPU is also an option
}
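
For intuition, the host-side loop at the end of gpu_nms is the standard greedy suppression pass over the per-block bit masks. A rough Python equivalent of just that loop (my sketch, not part of the extension) would be:

# mask[i * col_blocks + j] has bit b set if box i overlaps box j*64+b above the threshold
def greedy_suppress(mask, boxes_num, threads_per_block=64):
    col_blocks = (boxes_num + threads_per_block - 1) // threads_per_block
    remv = [0] * col_blocks                          # bitset of suppressed boxes
    keep = []
    for i in range(boxes_num):
        nblock, inblock = divmod(i, threads_per_block)
        if not (remv[nblock] >> inblock) & 1:        # box i is not suppressed yet
            keep.append(i)
            for j in range(nblock, col_blocks):      # suppress everything it overlaps
                remv[j] |= mask[i * col_blocks + j]
    return keep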

To load the library, you can then import it in a Python file and have a possibly working 3D NMS as follows:

import torch
torch.ops.load_library("build/libnms_cmake.so")

def nms_gpu(dets, thresh):
  """
  dets has to be a tensor
  """

  scores = dets[:, -1]
  order = scores.sort(0, descending=True)[1]
  dets = dets[order].contiguous()

  keep = torch.LongTensor(dets.size(0))
  num_out = torch.LongTensor(1)
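  # note: the CUDA kernel reads exactly 7 floats per box, so if dets also
  # carries the score column, it may need to be stripped (e.g. dets[:, :7])
  # before being handed to the op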
  num_to_keep = torch.ops.pytorch_cmake_example.gpu_nms(keep, dets, num_out, thresh)  # namespace from TORCH_LIBRARY
  return order[keep[:num_to_keep].cuda()].contiguous()
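
For a quick sanity check, something along these lines should at least exercise the op (hypothetical random values, untested, and subject to the 7-float stride caveat noted in the wrapper):

# hypothetical smoke test: random corner-format boxes plus a score column
boxes = torch.rand(100, 7, device="cuda")
scores = torch.rand(100, 1, device="cuda")
dets = torch.cat([boxes, scores], dim=1)  # (N, 8), score last
keep = nms_gpu(dets, 0.5)
print(keep.shape)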

Thanks for the help.