How to get the data pointer of a torch.cuda.HalfTensor?

Hi everyone, I am trying to write some custom CUDA kernel functions for torch.cuda.HalfTensor. However, I don’t know how to get the data pointer of a torch.cuda.HalfTensor. Here is my code:


#include <iostream>

#include <torch/extension.h>

torch::Tensor mul(torch::Tensor & x, torch::Tensor & y);


    m.def("mul", &mul);


#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

#include <limits>
#include <math.h>
#include <stdio.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>

// Element-wise float16 product: z[i] = x[i] * y[i] for every i in [0, size).
//
// Each thread handles at most one element, the one at its global index;
// threads whose index is >= size do nothing, so the grid may be rounded up.
// x, y and z are declared __restrict__, so they must not overlap.
// NOTE(review): on older CUDA toolkits __hmul is only provided for compute
// capability >= 5.3 — confirm the target architecture.
__global__ void mul_cuda_kernel(const __half* __restrict__ x, const __half* __restrict__ y, __half* __restrict__ z, const int size)
{
    const int index = blockIdx.x * blockDim.x + threadIdx.x;

    if (index < size) {
        z[index] = __hmul(x[index], y[index]);
    }
}

// Multiply two float16 CUDA tensors element-wise and return the result.
//
// x, y: CUDA tensors of dtype float16 with identical shapes on the same device.
//       Non-contiguous inputs are copied into contiguous locals; the caller's
//       tensor handles are left untouched.
// Returns a new contiguous float16 tensor with x's shape on x's device.
// Raises (TORCH_CHECK -> RuntimeError in Python) on a wrong device, dtype or
// shape, on tensors with more than INT_MAX elements (the kernel indexes with
// int), or if the kernel launch fails.
torch::Tensor mul(torch::Tensor & x, torch::Tensor & y)
{
    TORCH_CHECK(x.device().is_cuda(), "x must be a CUDA tensor");
    TORCH_CHECK(y.device().is_cuda(), "y must be a CUDA tensor");
    TORCH_CHECK(x.scalar_type() == at::kHalf, "x must be a float16 tensor");
    TORCH_CHECK(y.scalar_type() == at::kHalf, "y must be a float16 tensor");
    TORCH_CHECK(x.device() == y.device(), "x and y must be on the same device");
    // Without this, a smaller y would be read out of bounds by the kernel.
    TORCH_CHECK(x.sizes() == y.sizes(), "x and y must have the same shape");
    TORCH_CHECK(x.numel() <= static_cast<int64_t>(std::numeric_limits<int>::max()),
                "mul supports at most INT_MAX elements");

    // Allocate and launch on x's device even if it is not the current device.
    const c10::cuda::CUDAGuard device_guard(x.device());

    // contiguous() returns the tensor itself when it is already contiguous.
    const torch::Tensor xc = x.contiguous();
    const torch::Tensor yc = y.contiguous();
    // The kernel writes every element, so a zero-fill is unnecessary.
    torch::Tensor z = torch::empty_like(xc);

    const int64_t numel = xc.numel();
    if (numel == 0) {
        // A launch with zero blocks is an invalid configuration.
        return z;
    }

    const int size = static_cast<int>(numel);
    const int threads = 1024;
    // Computed in 64 bits so that size + threads - 1 cannot overflow int.
    const int64_t blocks = (numel + threads - 1) / threads;

    // Tensor::data_ptr<T>() is only instantiated for ATen scalar types:
    // data_ptr<__half>() compiles but leaves an undefined symbol at load time.
    // Request at::Half* and reinterpret it; at::Half and __half are both
    // 2-byte IEEE binary16 values.
    const __half* x_ptr = reinterpret_cast<const __half*>(xc.data_ptr<at::Half>());
    const __half* y_ptr = reinterpret_cast<const __half*>(yc.data_ptr<at::Half>());
    __half* z_ptr = reinterpret_cast<__half*>(z.data_ptr<at::Half>());

    // Launch on PyTorch's current stream so the kernel is ordered with other
    // PyTorch work on this device.
    mul_cuda_kernel<<<static_cast<unsigned int>(blocks), threads, 0,
                      at::cuda::getCurrentCUDAStream()>>>(x_ptr, y_ptr, z_ptr, size);

    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "mul_cuda_kernel launch failed: ", cudaGetErrorString(err));

    return z;
}


import torch

import torch.nn as nn

import torch.nn.functional as F

from torch.utils import cpp_extension

# JIT-compile the binding file and the CUDA source into an importable module
# named `cext_temp`. This needs nvcc and a CUDA device at run time.
# NOTE(review): the CUDA source's file name was lost from the original snippet;
# './' is a directory, not a source file. './c.cu' is an assumed name; change
# it to the .cu file that defines mul_cuda_kernel and mul().
cext_temp = cpp_extension.load(
    name='cext_temp',
    sources=['./c.cpp', './c.cu'],
    verbose=True,
)

device = 'cuda:0'

# Create the inputs directly as float16 rather than casting float32 tensors.
x = torch.rand(8, device=device, dtype=torch.float16)
y = torch.rand(8, device=device, dtype=torch.float16)

z = cext_temp.mul(x, y)
print(z)

# Sanity check against PyTorch's own float16 multiply. The tolerance allows
# for one float16 rounding step in case the two paths round differently.
if not torch.allclose(z, x * y, rtol=1e-3, atol=1e-5):
    raise RuntimeError('cext_temp.mul disagrees with x * y')

When I run it, I get the following error:

  File "/home/wfang/spikingjellyCppExt/temp/", line 6, in <module>
    cext_temp = cpp_extension.load(name='cext_temp',
  File "/home/wfang/anaconda3/envs/pytorch-env/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 969, in load
    return _jit_compile(
  File "/home/wfang/anaconda3/envs/pytorch-env/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1196, in _jit_compile
    return _import_module_from_library(name, build_directory, is_python_module)
  File "/home/wfang/anaconda3/envs/pytorch-env/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1547, in _import_module_from_library
    return imp.load_module(module_name, file, path, description)
  File "/home/wfang/anaconda3/envs/pytorch-env/lib/python3.8/imp.py", line 242, in load_module
    return load_dynamic(name, filename, file)
  File "/home/wfang/anaconda3/envs/pytorch-env/lib/python3.8/imp.py", line 342, in load_dynamic
    return _load(spec)
  File "<frozen importlib._bootstrap>", line 702, in _load
  File "<frozen importlib._bootstrap>", line 657, in _load_unlocked
  File "<frozen importlib._bootstrap>", line 556, in module_from_spec
  File "<frozen importlib._bootstrap_external>", line 1101, in create_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
ImportError: /home/wfang/.cache/torch_extensions/cext_temp/cext_temp.so: undefined symbol: _ZNK2at6Tensor8data_ptrI6__halfEEPT_v

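I found that I must use `at::Half` rather than `half`. `Tensor::data_ptr<T>()` is only instantiated for ATen's own scalar types. As a result, `data_ptr<__half>()` compiles but leaves the undefined symbol shown above, which demangles to `at::Tensor::data_ptr<__half>() const`. Another question: how can I use `half2`? It seems that `at::Half2` does not exist.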

Hey, I ran into the same problem. May I ask how you convert at::Half to half when computing in the CUDA kernel? Thanks for your help!

In C++, you can use `x.data_ptr<at::Half>()` to get an `at::Half*` pointer. Then cast it to `half*`, for example with `reinterpret_cast<half*>(x.data_ptr<at::Half>())`, and pass it to the CUDA kernel.

Thanks! I’m just not sure whether at::Half is the same type as half.
I use `(half*)(x.data_ptr<at::Half>())` to get the pointer, and it works fine for now ☺️.