Getting `__half*` out of an fp16 tensor?

Is there a clean way to get a native pointer out of an fp16 tensor inside `AT_DISPATCH_FLOATING_TYPES_AND_HALF`? For the other floating-point types (fp32, fp64), `tensor.data_ptr<scalar_t>()` already returns the right native pointer type (`float*`, `double*`). For fp16, however, it returns `c10::Half*`, which isn't type-compatible with `__half*`.

The best workaround I’ve come up with is:

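```cpp
#include <ATen/ATen.h>   // at::Tensor, c10::Half
#include <cuda_fp16.h>   // __half

// Maps a PyTorch scalar type to the type CUDA code expects:
// identity for everything except c10::Half, which becomes __half.
template<typename U>
struct native_type {
  using T = U;
};

template<>
struct native_type<c10::Half> {
  using T = __half;
};

// Drop-in replacement for t.data_ptr<U>() that returns the native pointer type.
template<typename U>
typename native_type<U>::T* ptr(const at::Tensor& t) {
  return reinterpret_cast<typename native_type<U>::T*>(t.data_ptr<U>());
}
```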

I then call `ptr<scalar_t>(tensor)` everywhere I would normally call `tensor.data_ptr<scalar_t>()`, and I have to write `typename native_type<scalar_t>::T` for every template argument that needs `__half` rather than `c10::Half`.

Is there a better alternative or existing code in PyTorch that I can reuse?