Hi everyone, I am trying to write some custom CUDA kernel functions for torch.cuda.HalfTensor. However, I don’t know how to get pointer of torch.cuda.HalfTensor. Here are my codes:
1 Like
c.cpp:
#include <iostream>
#include <torch/extension.h>
torch::Tensor mul(torch::Tensor & x, torch::Tensor & y);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("mul", &mul);
}
kernel.cu:
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <math.h>
#include <stdio.h>
#include <cuda_fp16.h>
__global__ void mul_cuda_kernel(const __half* __restrict__ x, const __half* __restrict__ y, __half* __restrict__ z, const int size)
{
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < size) {
z[index] = __hmul(x[index], y[index]);
}
}
torch::Tensor mul(torch::Tensor & x, torch::Tensor & y)
{
TORCH_CHECK(x.device().is_cuda(), "x must be a CUDA tensor");
TORCH_CHECK(y.device().is_cuda(), "y must be a CUDA tensor");
if (! x.is_contiguous())
{
x = x.contiguous();
}
if (! y.is_contiguous())
{
y = y.contiguous();
}
auto z = torch::zeros_like(x.data());
const int size = x.numel();
const int threads = 1024;
const int blocks = (size + threads - 1) / threads;
cudaSetDevice(x.get_device());
mul_cuda_kernel<<<blocks, threads>>>(x.data_ptr<__half>(), y.data_ptr<__half>(), z.data_ptr<__half>(), size);
return z;
}
test.py:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import cpp_extension
cext_temp = cpp_extension.load(name='cext_temp',
sources=['./c.cpp', './kernel.cu'], verbose=True)
device = 'cuda:0'
x = torch.rand([8], device=device).half()
y = torch.rand([8], device=device).half()
print(x)
print(y)
print(x*y)
print(cext_temp.mul(x, y))
When I run test.py, I get the error:
File "/home/wfang/spikingjellyCppExt/temp/test.py", line 6, in <module>
cext_temp = cpp_extension.load(name='cext_temp',
File "/home/wfang/anaconda3/envs/pytorch-env/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 969, in load
return _jit_compile(
File "/home/wfang/anaconda3/envs/pytorch-env/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1196, in _jit_compile
return _import_module_from_library(name, build_directory, is_python_module)
File "/home/wfang/anaconda3/envs/pytorch-env/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1547, in _import_module_from_library
return imp.load_module(module_name, file, path, description)
File "/home/wfang/anaconda3/envs/pytorch-env/lib/python3.8/imp.py", line 242, in load_module
return load_dynamic(name, filename, file)
File "/home/wfang/anaconda3/envs/pytorch-env/lib/python3.8/imp.py", line 342, in load_dynamic
return _load(spec)
File "<frozen importlib._bootstrap>", line 702, in _load
File "<frozen importlib._bootstrap>", line 657, in _load_unlocked
File "<frozen importlib._bootstrap>", line 556, in module_from_spec
File "<frozen importlib._bootstrap_external>", line 1101, in create_module
File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
ImportError: /home/wfang/.cache/torch_extensions/cext_temp/cext_temp.so: undefined symbol: _ZNK2at6Tensor8data_ptrI6__halfEEPT_v
I find that I must use at::Half
rather than half
. Another question is how can I use half2
? It seems that at::Half2
does not exist.
Hey, I also encountered this question. May I ask how you convent at::Half
to half
while computing in cuda kernel? Thanks for your help!
In c++, you can use x.data_ptr<at::Half>()
to get a half pinter half* x
, which is used in cuda kernel.
Thanks! I’m just not sure whether the at::Half
is just the same half
or not.
I use the (half*)(x.data_ptr<at::Half>())
to get the pointer. It works fine for now .