A problem of CUSTOM C++ AND CUDA EXTENSIONS

Hello all,

I’m trying to build a C function for CUDA. With torch.utils.ffi module I can build the source without errors. Because the ffi module is deprecated, so I take the same source code and I use the torch.utils.cpp_extension to build the module, however, it gives compling errors. Here is the C code.

#include <TH/TH.h>
#include <stdbool.h>
#include <stdio.h>

#define real float

# definition of roi_crop.c
int BilinearSamplerBHWD_updateOutput(THFloatTensor *inputImages, THFloatTensor *grids, THFloatTensor *output)
{
  **######The pointer operation trigers the erros######**
  int batchsize = inputImages->size[0];
  int inputImages_height = inputImages->size[1];
  int inputImages_width = inputImages->size[2];
  int output_height = output->size[1];
  int output_width = output->size[2];
  int inputImages_channels = inputImages->size[3];

  int output_strideBatch = output->stride[0];
  int output_strideHeight = output->stride[1];
  int output_strideWidth = output->stride[2];

  int inputImages_strideBatch = inputImages->stride[0];
  int inputImages_strideHeight = inputImages->stride[1];
  int inputImages_strideWidth = inputImages->stride[2];

  int grids_strideBatch = grids->stride[0];
  int grids_strideHeight = grids->stride[1];
  int grids_strideWidth = grids->stride[2];


  real *inputImages_data, *output_data, *grids_data;
  inputImages_data = THFloatTensor_data(inputImages);
  output_data = THFloatTensor_data(output);
  grids_data = THFloatTensor_data(grids);

  int b, yOut, xOut;

  for(b=0; b < batchsize; b++)
  {
    for(yOut=0; yOut < output_height; yOut++)
    {
      for(xOut=0; xOut < output_width; xOut++)
      {
        //read the grid
        real yf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth];
        real xf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth + 1];

        // get the weights for interpolation
        int yInTopLeft, xInTopLeft;
        real yWeightTopLeft, xWeightTopLeft;

        real xcoord = (xf + 1) * (inputImages_width - 1) / 2;
        xInTopLeft = floor(xcoord);
        xWeightTopLeft = 1 - (xcoord - xInTopLeft);

        real ycoord = (yf + 1) * (inputImages_height - 1) / 2;
        yInTopLeft = floor(ycoord);
        yWeightTopLeft = 1 - (ycoord - yInTopLeft);



        const int outAddress = output_strideBatch * b + output_strideHeight * yOut + output_strideWidth * xOut;
        const int inTopLeftAddress = inputImages_strideBatch * b + inputImages_strideHeight * yInTopLeft + inputImages_strideWidth * xInTopLeft;
        const int inTopRightAddress = inTopLeftAddress + inputImages_strideWidth;
        const int inBottomLeftAddress = inTopLeftAddress + inputImages_strideHeight;
        const int inBottomRightAddress = inBottomLeftAddress + inputImages_strideWidth;

        real v=0;
        real inTopLeft=0;
        real inTopRight=0;
        real inBottomLeft=0;
        real inBottomRight=0;

        // we are careful with the boundaries
        bool topLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1;
        bool topRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1;
        bool bottomLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1;
        bool bottomRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1;

        int t;
        // interpolation happens here
        for(t=0; t<inputImages_channels; t++)
        {
           if(topLeftIsIn) inTopLeft = inputImages_data[inTopLeftAddress + t];
           if(topRightIsIn) inTopRight = inputImages_data[inTopRightAddress + t];
           if(bottomLeftIsIn) inBottomLeft = inputImages_data[inBottomLeftAddress + t];
           if(bottomRightIsIn) inBottomRight = inputImages_data[inBottomRightAddress + t];

           v = xWeightTopLeft * yWeightTopLeft * inTopLeft
             + (1 - xWeightTopLeft) * yWeightTopLeft * inTopRight
             + xWeightTopLeft * (1 - yWeightTopLeft) * inBottomLeft
             + (1 - xWeightTopLeft) * (1 - yWeightTopLeft) * inBottomRight;

           output_data[outAddress + t] = v;
        }

      }
    }
  }

  return 1;
}

Here is the build script for CUDAExtension:

import glob
import os
from setuptools import find_packages, setup
import torch
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension

torch_ver = [int(x) for x in torch.__version__.split(".")[:2]]
assert torch_ver >= [1, 3], "Requires PyTorch >= 1.3"


def get_extensions():
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, "src")
    
    sources = glob.glob(os.path.join(extensions_dir, "**", "*.c")) + glob.glob(
        os.path.join(extensions_dir, "*.c")
    )
    source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu")) + glob.glob(
        os.path.join(extensions_dir, "*.cu")
    )
    
    extension = CppExtension
    
    extra_compile_args = {"cxx": [
    
    ]}

    define_macros = []
    
    if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv("FORCE_CUDA", "0") == "1":
        extension = CUDAExtension
        sources += source_cuda
        define_macros += [("WITH_CUDA", None)]
        extra_compile_args["nvcc"] = [
            "-DCUDA_HAS_FP16=1",
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
        
        ]
        
        # It's better if pytorch can do this by default ..
        CC = os.environ.get("CC", None)
        if CC is not None:
            extra_compile_args["nvcc"].append("-ccbin={}".format(CC))
    
    sources = [os.path.join(extensions_dir, s) for s in sources]
    
    include_dirs = [extensions_dir]
    
    ext_modules = [
        extension(
            "_ext.roi_crop._roi_crop",
            sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]
    
    return ext_modules


#building Extension
setup(
    name="roi_crop.pytorch",
    python_requires=">=3.6",
    ext_modules=get_extensions(),
    cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
    )

And the errors:

BilinearSampling/roi_crop/src/roi_crop.c:10:38: error: invalid types ‘<unresolved overloaded function type>[int]’ for array subscript
   int batchsize = inputImages->size[0];
                                      ^
BilinearSampling/roi_crop/src/roi_crop.c:11:47: error: invalid types ‘<unresolved overloaded function type>[int]’ for array subscript
   int inputImages_height = inputImages->size[1];
                                               ^
BilinearSampling/roi_crop/src/roi_crop.c:12:46: error: invalid types ‘<unresolved overloaded function type>[int]’ for array subscript
   int inputImages_width = inputImages->size[2];
                                              ^
BilinearSampling/roi_crop/src/roi_crop.c:13:37: error: invalid types ‘<unresolved overloaded function type>[int]’ for array subscript
   int output_height = output->size[1];
                                     ^
BilinearSampling/roi_crop/src/roi_crop.c:14:36: error: invalid types ‘<unresolved overloaded function type>[int]’ for array subscript
   int output_width = output->size[2];
                                    ^
BilinearSampling/roi_crop/src/roi_crop.c:15:49: error: invalid types ‘<unresolved overloaded function type>[int]’ for array subscript
   int inputImages_channels = inputImages->size[3];
                                                 ^
BilinearSampling/roi_crop/src/roi_crop.c:17:44: error: invalid types ‘<unresolved overloaded function type>[int]’ for array subscript
   int output_strideBatch = output->stride[0];
                                            ^
BilinearSampling/roi_crop/src/roi_crop.c:18:45: error: invalid types ‘<unresolved overloaded function type>[int]’ for array subscript
   int output_strideHeight = output->stride[1];
                                             ^
BilinearSampling/roi_crop/src/roi_crop.c:19:44: error: invalid types ‘<unresolved overloaded function type>[int]’ for array subscript
   int output_strideWidth = output->stride[2];
                                          ^

As the errors indicate that pointer operation triggers the error. Since it’s C Code, I managed to assign the following two flags to NVCC before setup runs

export CXXFLAGS=“-std=c++11”
export CFLAGS="-std=c99

I got a warning like

cc1plus: warning: command line option ‘-std=c99’ is valid for C/ObjC but not for C++

Since the building is successful with theffi module, it implies that there is no syntactic issue in the codes.
I think I might have failed to give the correct compiling flags with cpp_extension. Any suggestion that helps me out here will be thankful.

This is legacy code, I am not 100% sure.
Can you try, for eg,
THFloatTensor_size(output, 0) and
THFloatTensor_stride(output, 0). Same thing for other indices.

hey @glaringlee thanks a lot for your inputs. Yes, you are right. it’s the code from maskrcnn.
Here is what I tried:

  • I replaced the THCudaTensor stuff with at::Tensor…

The signature of the function BilinearSamplerBHWD_updateOutput looks like

int BilinearSamplerBHWD_updateOutput_cpu(const at::Tensor& inputImages, const at::Tensor& grids, at::Tensor& output)
{
      int batchsize = inputImages.size(0);
      int inputImages_height = inputImages.size(1);
      int inputImages_width = inputImages.size(2);
      int output_height = output.size(0);
      int output_width = output.size(2);
      int inputImages_channels = inputImages.size(3);

      int output_strideBatch = output.stride(0);
      int output_strideHeight = output.stride(1);
      int output_strideWidth = output.stride(2);

      int inputImages_strideBatch = inputImages.stride(0);
      int inputImages_strideHeight = inputImages.stride(1);
      int inputImages_strideWidth = inputImages.stride(2);

      int grids_strideBatch = grids.stride(0);
      int grids_strideHeight = grids.stride(1);
      int grids_strideWidth = grids.stride(2);

      real *inputImages_data, *output_data, *grids_data;
      // TODO: inputImages.scalar_type()
      inputImages_data = inputImages.data_ptr<float>();
      output_data = output.data_ptr<float>();
      grids_data = grids.data_ptr<float>();
      ....

After replacement, the code works, but I guess I should use the template to re-write this function, as **TODO** line highlights, according to the type of inputs, the function can automatically cast the data type as it needs. I’ll re-write all functions with the template later.

  • Second try, as you suggested, I replaced them and complied the codes without errors. The signature of the function BilinearSamplerBHWD_updateOutput looks as follows:
int BilinearSamplerBHWD_updateOutput_cpu(THCudaTensor* inputImages, THCudaTensor* grids, THCudaTensor* output)

However, in this case, the complied functions in .so file are never called. The program seeds like freeze forever. The codes where the C++ function got called are defined as follows:

class RoICropFunction(Function):
    @staticmethod
    def forward(self, input1, input2):
        print("RoICropFunction forward")
        self.input1 = input1.clone()
        self.input2 = input2.clone()
        output = input2.new(input2.size()[0], input1.size()[1], input2.size()[1], input2.size()[2]).zero_()
        assert output.get_device() == input1.get_device(), "output and input1 must on the same device"
        assert output.get_device() == input2.get_device(), "output and input2 must on the same device"
        #  It freeze at this line
        _C.BilinearSamplerBHWD_updateOutput(input1, input2, output)

Any ideas to figure it out, what’s going on there? Thanks in advance! -Anakin

@Anakin
int BilinearSamplerBHWD_updateOutput_cpu(THCudaTensor* inputImages, THCudaTensor* grids, THCudaTensor* output)

Is this the right function you want to paste? It seems this is a function running on cpu side, but accept CUDA tensor. Second thing is that, what you called is _C.BilinearSamplerBHWD_updateOutput without _cpu. Can you check?

@glaringlee Thanks for your inputs again.
I guess I made it a little bitte confused. Here are the complete signatures of all functions as well as the PYBIND11_MODULE. Hopefully I do not make any goofy errors here.


> PYBIND11_MODULE

#include " roi_crop.h"

//namespace densecap {

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("BilinearSamplerBHWD_updateOutput", &BilinearSamplerBHWD_updateOutput, "BilinearSamplerBHWD_updateOutput");
  m.def("BilinearSamplerBHWD_updateGradInput", &BilinearSamplerBHWD_updateGradInput, "BilinearSamplerBHWD_updateGradInput");
}

//} // densecap


> Definition of interface: roi_crop.h

#ifndef ROI_CROP_CPU_H_
#define ROI_CROP_CPU_H_

#include <THC/THC.h>
#include <c10/core/Device.h>

#define USE_AT_TENSOR

#include "cpu/vision.h"

#ifdef WITH_CUDA
#include "cuda/vision.h"
#endif


#ifdef USE_AT_TENSOR
    int BilinearSamplerBHWD_updateOutput(const at::Tensor& inputImages, const at::Tensor& grids, at::Tensor& output) {
        int ret = 1;

        if (inputImages.type().is_cuda()) {
        #ifdef WITH_CUDA
            return BilinearSamplerBHWD_updateOutput_cuda(inputImages, grids, output);
        #else
            AT_ERROR("Not compiled with GPU support");
        #endif // WITH_CUDA
            }

          ret = BilinearSamplerBHWD_updateOutput_cpu(inputImages, grids, output);
          return ret;
    }

    int BilinearSamplerBHWD_updateGradInput(const at::Tensor& inputImages, const at::Tensor& grids, const at::Tensor& gradInputImages,
                                        const at::Tensor& gradGrids,  at::Tensor& gradOutput){
        int ret = 1;

        if (inputImages.type().is_cuda()) {
        #ifdef WITH_CUDA
            return BilinearSamplerBHWD_updateGradInput_cuda(inputImages, grids, gradInputImages, gradGrids, gradOutput);
        #else
            AT_ERROR("Not compiled with GPU support");
        #endif // WITH_CUDA
            }

          ret = BilinearSamplerBHWD_updateGradInput_cpu(inputImages, grids, gradInputImages, gradGrids, gradOutput);
      return ret;
    }
#else
    int BilinearSamplerBHWD_updateOutput(THCudaTensor* inputImages, THCudaTensor* grids, THCudaTensor* output) {
         printf("BilinearSamplerBHWD_updateOutput ");
            int ret = 1;

            if (inputImages->device_type() == c10::DeviceType::CUDA) {
        #ifdef WITH_CUDA
            return BilinearSamplerBHWD_updateOutput_cuda(inputImages, grids, output);
        #else
            AT_ERROR("Not compiled with GPU support");
        #endif // WITH_CUDA
            }

          ret = BilinearSamplerBHWD_updateOutput_cpu(inputImages, grids, output);
          return ret;
     }


    int BilinearSamplerBHWD_updateGradInput(THCudaTensor* inputImages, THCudaTensor* grids, THCudaTensor* gradInputImages,
                                        THCudaTensor* gradGrids,  THCudaTensor* gradOutput){
        int ret = 1;

        //if (inputImages.type().is_cuda()) {
        if (inputImages->device_type() == c10::DeviceType::CUDA) {
        #ifdef WITH_CUDA
            return BilinearSamplerBHWD_updateGradInput_cuda(inputImages, grids, gradInputImages, gradGrids, gradOutput);
        #else
            AT_ERROR("Not compiled with GPU support");
        #endif // WITH_CUDA
            }

          ret = BilinearSamplerBHWD_updateGradInput_cpu(inputImages, grids, gradInputImages, gradGrids, gradOutput);
      return ret;
    }
#endif  //USE_AT_TENSOR
#endif //#define ROI_CROP_CPU_H_

> Definition of interface for cpu <cpu/vision.h>

#pragma once
#include <torch/extension.h>
#define USE_AT_TENSOR

#ifdef USE_AT_TENSOR
    int BilinearSamplerBHWD_updateOutput_cpu(const at::Tensor& inputImages, const at::Tensor& grids, at::Tensor& output);
    int BilinearSamplerBHWD_updateGradInput_cpu(const at::Tensor& inputImages, const at::Tensor& grids, const at::Tensor& gradInputImages,
                                        const at::Tensor& gradGrids, at::Tensor& gradOutput);
#else
    int BilinearSamplerBHWD_updateOutput_cpu(THCudaTensor* inputImages, THCudaTensor* grids, THCudaTensor* output);
    int BilinearSamplerBHWD_updateGradInput_cpu(THCudaTensor* inputImages, THCudaTensor* grids, THCudaTensor* gradInputImages,
                                        THCudaTensor* gradGrids, THCudaTensor* gradOutput);
#endif  // USE_AT_TENSOR

The corresponding definition of interfaces for cuda

> Definition of interface for cuda <cuda/vision.h>
#pragma once
#include <torch/extension.h>
#define USE_AT_TENSOR

#ifdef USE_AT_TENSOR
    int BilinearSamplerBHWD_updateOutput_cuda(const at::Tensor& inputImages, const at::Tensor& grids, at::Tensor& output);
    int BilinearSamplerBHWD_updateGradInput_cuda(const at::Tensor& inputImages, const at::Tensor& grids, const at::Tensor& gradInputImages,
                                        const at::Tensor& gradGrids, at::Tensor& gradOutput);
#else
    int BilinearSamplerBHWD_updateOutput_cuda(THCudaTensor* inputImages, THCudaTensor* grids, THCudaTensor* output);
    int BilinearSamplerBHWD_updateGradInput_cuda(THCudaTensor* inputImages, THCudaTensor* grids, THCudaTensor* gradInputImages,
                                        THCudaTensor* gradGrids, THCudaTensor* gradOutput);
#endif // USE_AT_TENSOR

In general, BilinearSamplerBHWD_updateOutput is the interface for python codes.
At the running time, depending on the device info, the functions are directed to call BilinearSamplerBHWD_updateOutput_cpu or BilinearSamplerBHWD_updateOutput_cuda.

Here is a brief schematic representation of this calling sequence.

Python Module --> BilinearSamplerBHWD_updateOutput 
                        | -->BilinearSamplerBHWD_updateOutput_cpu
                        |-->BilinearSamplerBHWD_updateOutput_cuda  

The definition of Python Module is given as follows:

class RoICropFunction(Function):
    @staticmethod
    def forward(self, input1, input2):
        self.input1 = input1.clone()
        self.input2 = input2.clone()
        output = input2.new(input2.size()[0], input1.size()[1], input2.size()[1], input2.size()[2]).zero_()
        assert output.get_device() == input1.get_device(), "output and input1 must on the same device"
        assert output.get_device() == input2.get_device(), "output and input2 must on the same device"
        # I'll freeze here if C++ interface is declared with THCudaTensor
        _C.BilinearSamplerBHWD_updateOutput(input1, input2, output)
        self.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        #result, = ctx.saved_tensors
        grad_input1 = ctx.input1.new(ctx.input1.size()).zero_()
        grad_input2 = ctx.input2.new(ctx.input2.size()).zero_()
        _C.BilinearSamplerBHWD_updateGradInput(ctx.input1, ctx.input2, grad_input1, grad_input2,
                                                          grad_output)
        return grad_input1, grad_input2

Hopefully, this helps us to locate the bugs. Thanks.

@Anakin
I see. I think the problem is that in your python code. All the tensors (eg, your input tensor) will be converted to at::Tensor , so if your BilinearSamplerBHWD_xxxx functions accept legacy TH tensors, it won’t get called. My suggestion previous only solved the compile problem. I recommend you to switch to at::Tensor, since TH is legacy, we don’t have any more support on that.

I currently not sure how to let BilinearSamplerBHWD_xxxx get called with TH tensors under cpp extension, got a bit busy in the following 2 days, let me know if you want to stick with TH tensors, I can research on this after this Friday.

@glaringlee Thanks. I’d like to use at::tensor. Cheers.