Here’s my code.
file: q_cu.cpp
#include <torch/torch.h>
#include <cmath>
// Element-wise quantization routine; its definition lives in q_cu_core.cu.
void run_cu(at::Tensor& out, const at::Tensor& x, int bw, int fl);

// Python-facing entry point.
//
// Allocates an uninitialized tensor with the same sizes and options as `x`,
// lets run_cu fill it, and returns it.
//
//   x      - input tensor. run_cu iterates it as `float` on the GPU, so it
//            must be a float32 CUDA tensor; anything else is rejected here
//            instead of failing inside the kernel launch.
//   bw, fl - forwarded unchanged to run_cu.
//
// The tensor is taken by const reference so that a call does not copy the
// tensor handle.
at::Tensor run(const at::Tensor& x, int bw, int fl) {
  TORCH_CHECK(x.is_cuda(), "run: expected a CUDA tensor");
  TORCH_CHECK(x.scalar_type() == at::kFloat, "run: expected a float32 tensor");

  at::Tensor out = at::empty(x.sizes(), x.options());
  run_cu(out, x, bw, fl);
  return out;
}
// Python bindings for this extension: the module is named by
// TORCH_EXTENSION_NAME and exposes the C++ function `run` as `run`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
      "run",
      &run,
      "run");
}
file: q_cu_core.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <cmath>
// Quantizes every element of `x` and writes the result to `out`.
//
// Each element is multiplied by 2^fl, rounded to the nearest integer, clamped
// to the signed range [-2^(bw-1), 2^(bw-1) - 1], and multiplied by 2^-fl.
//
//   out - destination tensor, iterated as `float`.
//   x   - source tensor, iterated as `float`.
//   bw  - total bit width of the signed integer range; must be in [1, 63].
//   fl  - exponent of the power-of-two scale.
void run_cu(at::Tensor& out, const at::Tensor& x, int bw, int fl) {
  // A shift by a negative amount or past the integer width is undefined.
  TORCH_CHECK(bw >= 1 && bw <= 63, "run_cu: bw must be in [1, 63], got ", bw);

  // Do the shift in 64 bits: `1 << (bw - 1)` in plain int overflows at bw = 32.
  const int64_t half_range = int64_t{1} << (bw - 1);
  // The bounds are only ever used as fminf/fmaxf arguments, so convert once.
  const float q_max = static_cast<float>(half_range - 1);
  const float q_min = -static_cast<float>(half_range);

  // ldexp builds the exact power of two in single precision.
  const float scale = std::ldexp(1.0f, fl);
  const float inv_scale = 1.0f / scale;

  // NOTE(review): CUDA_tensor_apply2 accepts optional arguments describing
  // which tensor is read and which is written; the defaults are kept as in
  // the original call. Confirm they match (x = read, out = write) for the
  // installed ATen headers.
  at::cuda::CUDA_tensor_apply2<float, float>(
      x, out, [=] __device__(const float& src, float& dst) {
        // Round in float and clamp directly: casting an out-of-range float
        // to int64_t first would be undefined behaviour.
        const float rounded = roundf(src * scale);
        dst = fminf(q_max, fmaxf(q_min, rounded)) * inv_scale;
      });
}
file: setup.py
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
# Build configuration for the `my_method2` CUDA extension, compiled from the
# host binding (q_cu.cpp) and the CUDA source (q_cu_core.cu).
setup(
    name="my_method",
    ext_modules=[
        CUDAExtension(
            "my_method2",
            ["q_cu.cpp", "q_cu_core.cu"],
            # Per-compiler flags: the extended-lambda switch is an nvcc
            # option (needed for the `__device__` lambda in q_cu_core.cu),
            # so it is passed to nvcc only and not to the host compiler.
            extra_compile_args={
                "cxx": [],
                "nvcc": ["--expt-extended-lambda"],
            },
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
There is no error when I compile this extension, but importing it fails:
import torch
import my_method2
# output
# ImportError...
# undefined symbol:
# _ZN2at6native6legacy4cuda27_th_copy_ignoring_overlaps_ERNS_6TensorERKS3_
How can I solve this problem?