Getting `__half*` out of an fp16 tensor?

Is there a clean way to get a native pointer out of an fp16 tensor inside `AT_DISPATCH_FLOATING_TYPES_AND_HALF`? For the other floating-point types (fp32, fp64), I can call `tensor.data_ptr<scalar_t>()` and get the right native pointer type back. For fp16, I get back `c10::Half*`, which isn't type-compatible with `__half*`.

The best workaround I’ve come up with is:

```cpp
template <typename U>
struct native_type {
  using T = U;
};

template <>
struct native_type<c10::Half> {
  using T = __half;
};

template <typename U>
typename native_type<U>::T* ptr(const Tensor& t) {
  return reinterpret_cast<typename native_type<U>::T*>(t.data_ptr<U>());
}
```
And then I'd call `ptr<scalar_t>(tensor)` every place I would normally call `tensor.data_ptr<scalar_t>()`. I also have to use `typename native_type<scalar_t>::T` for all template arguments that need `__half` instead of `c10::Half`.

Is there a better alternative or existing code in PyTorch that I can reuse?