Error: more than one operator "<" matches these operands in THCNumerics.cuh

We are trying to build a libtorch-based C++ project that uses custom kernels written in .cu files, but we ran into the compile errors below. By commenting out lines 191-197 of THCNumerics.cuh, we are able to get a clean build. We are using CUDA 10.2 and cuDNN v7.6.5 on a V100 GPU server.

I wonder what the real problem is here and what the best way to solve it is. Thanks!

./libtorch/include/THC/THCNumerics.cuh(191): error: more than one operator "<" matches these operands:
built-in operator "arithmetic < arithmetic"
function "operator<(const __half &, const __half &)"
operand types are: c10::Half < c10::Half

./libtorch/include/THC/THCNumerics.cuh(192): error: more than one operator "<=" matches these operands:
built-in operator "arithmetic <= arithmetic"
function "operator<=(const __half &, const __half &)"
operand types are: c10::Half <= c10::Half

./libtorch/include/THC/THCNumerics.cuh(193): error: more than one operator ">" matches these operands:
built-in operator "arithmetic > arithmetic"
function "operator>(const __half &, const __half &)"
operand types are: c10::Half > c10::Half

./libtorch/include/THC/THCNumerics.cuh(194): error: more than one operator ">=" matches these operands:
built-in operator "arithmetic >= arithmetic"
function "operator>=(const __half &, const __half &)"
operand types are: c10::Half >= c10::Half

./libtorch/include/THC/THCNumerics.cuh(195): error: more than one operator "==" matches these operands:
built-in operator "arithmetic == arithmetic"
function "operator==(const __half &, const __half &)"
operand types are: c10::Half == c10::Half

./libtorch/include/THC/THCNumerics.cuh(197): error: more than one operator "!=" matches these operands:
built-in operator "arithmetic != arithmetic"
function "operator!=(const __half &, const __half &)"
operand types are: c10::Half != c10::Half

From what I recall, you want to disable the built-in fp16 ops for PyTorch, something like

-DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__

or some such.
Quite likely the flags PyTorch itself builds with are indicative of what you need.
If you search the GitHub issues for these flags, you might find some more background. I seem to recall there was some issue with automatic casts.
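In case it helps, this is roughly what adding those flags looks like in a CMake project that compiles the custom .cu files itself (a sketch only, not tied to any particular project layout):

```cmake
# Define the cuda_fp16.h opt-out macros for every CUDA source in the
# project, so that c10::Half's own comparison operators are the only
# viable overloads and the ambiguity goes away.
enable_language(CUDA)
add_definitions(-D__CUDA_NO_HALF_OPERATORS__)
add_definitions(-D__CUDA_NO_HALF_CONVERSIONS__)
add_definitions(-D__CUDA_NO_HALF2_OPERATORS__)
```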

Best regards

Thomas


Thank you! Do we need to set these flags when we build libtorch? Or is there a way to set the flags somewhere in the libtorch build files?

So I only build libtorch through python3 setup.py bdist_wheel (throwing away the Python parts), but in principle, CMake builds should also set this for you.
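If you are configuring your own CMake build against a prebuilt libtorch, another option (a sketch; the paths are illustrative) is to pass the macros on the configure line rather than editing CMakeLists.txt:

```shell
cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch \
      -DCMAKE_CUDA_FLAGS="-D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__" \
      ..
```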

Should we set -DCUDA_HAS_FP16=0 instead of 1?

No. You want fp16 but not the built-in ops.

Got it, thanks a lot!

Hello, I encountered the same problem when building a libtorch-based C++ project. Those flags (-DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__) are set in the CMakeLists.txt as CUDA_NVCC_FLAGS, but it doesn't work. Should I rebuild libtorch, or how can this be solved in the right way? Thanks.
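One thing worth checking (a guess, since the full CMakeLists.txt isn't shown): CUDA_NVCC_FLAGS is only read by the legacy FindCUDA macros such as cuda_add_library. If the project enables CUDA as a first-class language via enable_language(CUDA), that variable is ignored, and the macros need to go through the regular CUDA flags or per-target definitions instead, e.g.:

```cmake
# Sketch for a project using first-class CUDA language support;
# "my_kernels" is an illustrative target name.
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__")
# or, scoped to a single target:
target_compile_definitions(my_kernels PRIVATE
  __CUDA_NO_HALF_OPERATORS__
  __CUDA_NO_HALF_CONVERSIONS__
  __CUDA_NO_HALF2_OPERATORS__)
```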

I faced an issue similar to this while installing torchvision from source. My error was

/home/user/miniconda3/envs/blender/lib/python3.10/site-packages/torch/include/ATen/native/cuda/KernelUtils.cuh(52): error: more than one operator "=" matches these operands:
            function "__half::operator=(float)"
/usr/local/cuda/include/cuda_fp16.hpp(218): here
            function "__half::operator=(__half &&)" (declared implicitly)
            operand types are: __half = c10::Half
          detected during:
            instantiation of "void at::native::fastSpecializedAtomicAdd(scalar_t *, index_t, index_t, scalar_t) [with scalar_t=c10::Half, index_t=int64_t, <unnamed>=(void *)nullptr]" 
(131): here
            instantiation of "void at::native::fastAtomicAdd(scalar_t *, index_t, index_t, scalar_t, __nv_bool) [with scalar_t=c10::Half, index_t=int64_t]" 
/home/user/vision/torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu(391): here
            instantiation of "void vision::ops::<unnamed>::deformable_col2im_kernel(index_t, const scalar_t *, const scalar_t *, const scalar_t *, index_t, index_t, index_t, index_t, index_t, index_t, index_t, index_t, index_t, index_t, index_t, index_t, index_t, index_t, index_t, __nv_bool, scalar_t *) [with scalar_t=c10::Half, index_t=int64_t]" 
/home/user/vision/torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu(439): here

Adding these flags helped. For those wondering where to put them, you can add them in CMakeLists.txt, e.g. from the CMakeLists.txt in torchvision:

if(WITH_CUDA)
  enable_language(CUDA)
  add_definitions(-D__CUDA_NO_HALF_OPERATORS__)
  add_definitions(-D__CUDA_NO_HALF_CONVERSIONS__)
  add_definitions(-D__CUDA_NO_HALF2_OPERATORS__)
  add_definitions(-DWITH_CUDA)
  set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
endif()