Relation between at::Half and __half

Hi, we want to create an atomicMax function for __half inputs, following pytorch_scatter/atomics.cuh at master. The relevant code is:

inline __device__ void operator()(at::Half *address, at::Half val) {       \
  unsigned int *address_as_ui =                                            \
      (unsigned int *)((char *)address - ((size_t)address & 2));           \
  unsigned int old = *address_as_ui;                                       \
  unsigned int assumed;                                                    \
                                                                           \
  do {                                                                     \
    assumed = old;                                                         \
    at::Half hsum;                                                         \
    hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);           \
    hsum = OP(hsum, val);                                                  \
    old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16)            \
                              : (old & 0xffff0000) | hsum.x;               \
    old = atomicCAS(address_as_ui, assumed, old);                          \
  } while (assumed != old);                                                \
}                                                                          \
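The packing arithmetic in that loop can be pulled out into two small integer helpers, which makes it easy to check on the host. This is a sketch only: the names `extract_half16` and `splice_half16` are hypothetical and are not part of pytorch_scatter or CUDA. `byte_offset` stands for `(size_t)address & 2`, i.e. 0 when the 16-bit value is the low half of its aligned 32-bit word and 2 when it is the high half (little-endian, as on CUDA GPUs).

```cuda
#include <cstddef>
#include <cstdint>

// Hypothetical helper: returns the 16-bit lane of `word` selected by
// byte_offset (0 = low 16 bits, 2 = high 16 bits).
__host__ __device__ inline std::uint16_t extract_half16(std::uint32_t word,
                                                         std::size_t byte_offset) {
  return byte_offset ? static_cast<std::uint16_t>(word >> 16)
                     : static_cast<std::uint16_t>(word & 0xffffu);
}

// Hypothetical helper: returns `word` with the selected 16-bit lane replaced
// by `bits`; the other lane is left unchanged.
__host__ __device__ inline std::uint32_t splice_half16(std::uint32_t word,
                                                        std::size_t byte_offset,
                                                        std::uint16_t bits) {
  return byte_offset
             ? (word & 0x0000ffffu) | (static_cast<std::uint32_t>(bits) << 16)
             : (word & 0xffff0000u) | static_cast<std::uint32_t>(bits);
}
```

In the at::Half loop, `hsum.x` plays the role of `bits`: it is read with the equivalent of `extract_half16` and written back with the equivalent of `splice_half16` before the 32-bit `atomicCAS`.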

For the __half type, however, we do not know how to obtain hsum.x. Is there a way to convert between at::Half and __half?

Also, how can we implement OP as max for __half inputs?

Thanks : )

It’s a bit clunky, but I believe simply:
*reinterpret_cast<__half*>(&myathalf);
and
*reinterpret_cast<at::Half*>(&myhalf);
would work. To convert pointers you would simply drop the dereference.

Thanks for your advice. Actually, we want to create atomicMax in a plain .cu file, without the at::Half definition. Is there a way to correctly replace at::Half with __half in the function above?

I am not familiar with pytorch/Half.h at master, but a simple replacement produces the error: class "__half" has no member "x"
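In plain CUDA, the closest counterpart of `at::Half::x` is the raw 16-bit pattern of the `__half`. A minimal sketch, assuming the `__half_raw` struct from `cuda_fp16.h` (its `unsigned short` member `x` holds the bits); in device code, the intrinsics `__half_as_ushort` and `__ushort_as_half` do the same job. The helper names `half_bits` and `half_from_bits` are hypothetical.

```cuda
#include <cuda_fp16.h>

// Hypothetical helper: returns the raw IEEE binary16 bits of h, the same
// 16 bits at::Half stores in its member x. The __half -> __half_raw
// conversion copies bits; it does not convert the numeric value.
__host__ __device__ inline unsigned short half_bits(__half h) {
  __half_raw r = h;
  return r.x;
}

// Hypothetical helper: builds a __half whose bit pattern is exactly `bits`.
__host__ __device__ inline __half half_from_bits(unsigned short bits) {
  __half_raw r;
  r.x = bits;
  return __half(r);
}
```

For example, 0x3C00 is the bit pattern of 1.0 and 0xC000 is the bit pattern of -2.0.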

What happens if you just use the value directly (remove the .x), if hsum is type __half?

Here is my modified code:

  unsigned int *address_as_ui = (unsigned int *)((char *)address - ((size_t)address & 2));
  unsigned int old = *address_as_ui;
  unsigned int assumed;
  do {
    assumed = old;
    __half hsum;
    hsum = reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
    hsum = __float2half(fmaxf(__half2float(hsum), __half2float(val)));
    old = (size_t)address & 2 ? (old & 0xffff) | (hsum << 16)
                              : (old & 0xffff0000) | hsum;
    old = atomicCAS(address_as_ui, assumed, old);
  } while (assumed != old);

The error appears at both uses of hsum in the calculation of old:

test.cu(57): error: more than one conversion function from "__half" to a built-in type applies:
            function "__half::operator short() const"
/cuda-11.4/include/cuda_fp16.hpp(222): here
            function "__half::operator unsigned short() const"
/cuda-11.4/include/cuda_fp16.hpp(225): here
            function "__half::operator int() const"
/cuda-11.4/include/cuda_fp16.hpp(228): here
            function "__half::operator unsigned int() const"
/cuda-11.4/include/cuda_fp16.hpp(231): here
            function "__half::operator long long() const"
/cuda-11.4/include/cuda_fp16.hpp(234): here
            function "__half::operator unsigned long long() const"
/cuda-11.4/include/cuda_fp16.hpp(237): here
            function "__half::operator __nv_bool() const"
/cuda-11.4/include/cuda_fp16.hpp(241): here

What happens if you explicitly cast hsum (and/or old) to short, as x would be in the original implementation?

Do you mean something like this?

  unsigned int *address_as_ui = (unsigned int *)((char *)address - ((size_t)address & 2));
  unsigned int old = *address_as_ui;
  unsigned int assumed;
  do {
    assumed = old;
    __half hsum;
    hsum = reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
    hsum = __float2half(fmaxf(__half2float(hsum), __half2float(val)));
    old = (size_t)address & 2 ? (old & 0xffff) | (__half_as_ushort(hsum) << 16)
                              : (old & 0xffff0000) | __half_as_ushort(hsum);
    old = atomicCAS(address_as_ui, assumed, old);
  } while (assumed != old);

This compiles and the error disappears, but the calculated result seems wrong…
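A likely cause, as far as I can tell, is the read side rather than the write side. `hsum = ... ? (old >> 16) : (old & 0xffff);` assigns an unsigned int to a __half, which performs a numeric integer-to-half conversion instead of reinterpreting the 16 bits. The sketch below contrasts the two on the bit pattern 0x3C00 (1.0 in binary16). The kernel and the host helper `run_compare_conversions` are hypothetical names used only for this illustration.

```cuda
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Compares the two ways of turning 16 bits held in an unsigned int into a __half.
__global__ void compare_conversions(float* out) {
  unsigned int word = 0x3C00u;
  __half numeric = word;  // numeric conversion: the integer 15360 becomes 15360.0
  __half bitwise = __ushort_as_half(static_cast<unsigned short>(word));  // bits reinterpreted: 1.0
  out[0] = __half2float(numeric);
  out[1] = __half2float(bitwise);
}

// Hypothetical host helper: runs the kernel and copies both results back.
// Returns false if any CUDA call fails.
bool run_compare_conversions(float result[2]) {
  float* d_out = nullptr;
  if (cudaMalloc(&d_out, 2 * sizeof(float)) != cudaSuccess) return false;
  compare_conversions<<<1, 1>>>(d_out);
  bool ok = cudaGetLastError() == cudaSuccess &&
            cudaMemcpy(result, d_out, 2 * sizeof(float), cudaMemcpyDeviceToHost) == cudaSuccess;
  cudaFree(d_out);
  return ok;
}
```

So the extracted bits need `__ushort_as_half(...)` on the way in, mirroring `__half_as_ushort(...)` on the way out.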

Would it make a difference if it was converted to a signed short?

I replaced __half_as_ushort with __half_as_short, but the result is still wrong.
Here is what we have now:

__device__ static void atomicMax(__half* address, __half val){

  unsigned int *address_as_ui = (unsigned int *)((char *)address - ((size_t)address & 2));
  unsigned int old = *address_as_ui;
  unsigned int assumed;
  do {
    assumed = old;
    __half hsum;
    hsum = reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
    hsum = __float2half(fmaxf(__half2float(hsum), __half2float(val)));
    old = (size_t)address & 2 ? (old & 0xffff) | (__half_as_short(hsum) << 16)
                              : (old & 0xffff0000) | __half_as_short(hsum);
    old = atomicCAS(address_as_ui, assumed, old);
  } while (assumed != old);
}
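For comparison, here is a sketch of the same 32-bit CAS scheme with the extracted bits reinterpreted via `__ushort_as_half` and written back via `__half_as_ushort`. It is an untested suggestion, not the pytorch_scatter implementation; `atomic_max_half`, `max_kernel` and `run_max` are hypothetical names, and the harness assumes the two adjacent `__half` slots share one 4-byte-aligned word (true for the start of a `cudaMalloc` allocation).

```cuda
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstddef>

// Sketch: atomic max on a __half via a 32-bit atomicCAS on the aligned word
// containing *address. Returns the value stored before this thread's update.
__device__ __half atomic_max_half(__half* address, __half val) {
  // 0 if *address is the low 16 bits of its 32-bit word, 2 if the high 16 bits.
  const std::size_t offset = reinterpret_cast<std::size_t>(address) & 2;
  unsigned int* address_as_ui =
      reinterpret_cast<unsigned int*>(reinterpret_cast<char*>(address) - offset);
  unsigned int old = *address_as_ui;
  unsigned int assumed;
  __half current;
  do {
    assumed = old;
    // Read side: reinterpret the 16 bits instead of converting the integer value.
    const unsigned short old_bits = offset ? static_cast<unsigned short>(old >> 16)
                                           : static_cast<unsigned short>(old & 0xffffu);
    current = __ushort_as_half(old_bits);
    // OP = max, evaluated in float; the max of two half values is itself a
    // half value, so converting back is exact.
    const __half updated = __float2half(fmaxf(__half2float(current), __half2float(val)));
    // Write side: unsigned raw bits, so nothing is sign-extended into the other lane.
    const unsigned int new_bits = __half_as_ushort(updated);
    const unsigned int desired = offset ? (old & 0x0000ffffu) | (new_bits << 16)
                                        : (old & 0xffff0000u) | new_bits;
    old = atomicCAS(address_as_ui, assumed, desired);
  } while (assumed != old);
  return current;
}

// Hypothetical test kernel: even threads fold i into slot 0, odd threads fold -i
// into slot 1; both slots live in the same 32-bit word.
__global__ void max_kernel(__half* slots, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    float v = (i % 2 == 0) ? static_cast<float>(i) : -static_cast<float>(i);
    atomic_max_half(&slots[i % 2], __float2half(v));
  }
}

// Hypothetical host helper: runs max_kernel and returns both slots as float.
bool run_max(int n, float result[2]) {
  __half init[2] = {__float2half(-1000.0f), __float2half(-1000.0f)};
  __half* d = nullptr;
  if (cudaMalloc(&d, sizeof(init)) != cudaSuccess) return false;
  bool ok = cudaMemcpy(d, init, sizeof(init), cudaMemcpyHostToDevice) == cudaSuccess;
  if (ok) {
    max_kernel<<<(n + 255) / 256, 256>>>(d, n);
    ok = cudaGetLastError() == cudaSuccess;
  }
  __half out[2];
  ok = ok && cudaMemcpy(out, d, sizeof(out), cudaMemcpyDeviceToHost) == cudaSuccess;
  cudaFree(d);
  if (ok) {
    result[0] = __half2float(out[0]);
    result[1] = __half2float(out[1]);
  }
  return ok;
}
```

With n = 1000, slot 0 should end at 998 and slot 1 at -1. Keeping the unsigned variant on the write side matters: with `__half_as_short`, a negative half in the low lane is sign-extended to 0xFFFFxxxx before the `|`, which overwrites the neighbouring half. On compute capability 7.0 and newer there is, to my knowledge, also a 16-bit `atomicCAS` overload that would avoid the word splicing; the sketch keeps the 32-bit scheme of the original.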

I also found the GitHub issue "error: more than one conversion function from "half" to a built-in type applies" (NVIDIA/cutlass, Issue #4). I tried adding

set(CMAKE_CUDA_ARCHITECTURES 86)
or
set(CUDA_ARCHITECTURE_FLAGS "86")

to my CMakeLists.txt, but it still does not work. What is the correct way to set this in CMake?