x2.cpp
#include <torch/extension.h>
#include <iostream>
#define _OPENMP
#include <ATen/ParallelOpenMP.h>
/// Element-wise square of `z`, computed one slice along dimension 0 at a time.
///
/// The slices are handed to at::parallel_for, and each invocation of the
/// loop body prints the OpenMP thread number it ran on once it has finished
/// its [start, end) range.
///
/// @param z  Input tensor. It must have at least one dimension, because
///           dimension 0 is the range that is split up by at::parallel_for.
/// @return   A newly allocated tensor with the same shape and options as `z`,
///           where slice b holds z[b] * z[b].
torch::Tensor x2(torch::Tensor z)
{
    // Read size(0) before allocating so that a 0-dim input still fails on
    // this lookup, as it did before.
    const int64_t batch_size = z.size(0);

    // Allocate from the full shape of z instead of {size(0), size(1)}: the
    // result for 2-D input is unchanged, and inputs of any other rank >= 1
    // now work instead of failing on size(1) or on a slice shape mismatch.
    torch::Tensor z_out = at::empty(z.sizes(), z.options());

    // grain_size is passed as 0, exactly as in the original call.
    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
        for (int64_t b = start; b < end; ++b)
        {
            // Writes the squared slice into position b of the output.
            z_out[b] = z[b] * z[b];
        }
        // Trace line showing which thread handled this range.
        std::cout << "hi there from " << omp_get_thread_num() << std::endl;
    });
    return z_out;
}
// Python bindings for the extension module named by TORCH_EXTENSION_NAME:
// registers the C++ function x2 under the Python name "x2" with the
// docstring "square".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext_module) {
    ext_module.def("x2", &x2, "square");
}
setup.py
from setuptools import setup, Extension
from torch.utils import cpp_extension
# Build description for the x2_cpp extension: compiles x2.cpp with OpenMP
# enabled and registers it with setuptools.
openmp_flag = '-fopenmp'

cpp_module = cpp_extension.CppExtension(
    'x2_cpp',
    sources=['x2.cpp'],
    extra_compile_args=[openmp_flag],
    # Pass -fopenmp to the link step too, instead of naming -lgomp directly,
    # so the compiler driver links the OpenMP runtime that matches the one it
    # compiled against.
    extra_link_args=[openmp_flag],
)

setup(
    name='x2_cpp',
    ext_modules=[cpp_module],
    cmdclass={'build_ext': cpp_extension.BuildExtension},
)
python setup.py install
build output:
running install
running bdist_egg
running egg_info
writing x2_cpp.egg-info/PKG-INFO
writing dependency_links to x2_cpp.egg-info/dependency_links.txt
writing top-level names to x2_cpp.egg-info/top_level.txt
reading manifest file 'x2_cpp.egg-info/SOURCES.txt'
writing manifest file 'x2_cpp.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_ext
building 'x2_cpp' extension
Emitting ninja build file *my_path :)*/x2_cpp/build/temp.linux-x86_64-3.7/build.ninja...
Compiling objects...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/1] c++ -MMD -MF my_path/x2_cpp/build/temp.linux-x86_64-3.7/x2.o.d -pthread -B my_path/anaconda3/envs/PyTorch/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/home/my_path/anaconda3/envs/PyTorch/lib/python3.7/site-packages/torch/include -I/home/my_path/anaconda3/envs/PyTorch/lib/python3.7/site-packages/torch/include/torch/csrc/api/include -I/home/my_path/anaconda3/envs/PyTorch/lib/python3.7/site-packages/torch/include/TH -I/home/my_path/anaconda3/envs/PyTorch/lib/python3.7/site-packages/torch/include/THC -I/home/my_path/anaconda3/envs/PyTorch/include/python3.7m -c -c /home/my_path/workspace/MotionSparsity/MSBackEnd/x2_cpp/x2.cpp -o /home/my_path/workspace/MotionSparsity/MSBackEnd/x2_cpp/build/temp.linux-x86_64-3.7/x2.o -fopenmp -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=x2_cpp -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++14
cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid for C/ObjC but not for C++
g++ -pthread -shared -B /home/my_path/anaconda3/envs/PyTorch/compiler_compat -L/home/my_path/anaconda3/envs/PyTorch/lib -Wl,-rpath=/home/my_path/anaconda3/envs/PyTorch/lib -Wl,--no-as-needed -Wl,--sysroot=/ /home/my_path/workspace/MotionSparsity/MSBackEnd/x2_cpp/build/temp.linux-x86_64-3.7/x2.o -L/home/my_path/anaconda3/envs/PyTorch/lib/python3.7/site-packages/torch/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -o build/lib.linux-x86_64-3.7/x2_cpp.cpython-37m-x86_64-linux-gnu.so -lgomp
creating build/bdist.linux-x86_64/egg
copying build/lib.linux-x86_64-3.7/x2_cpp.cpython-37m-x86_64-linux-gnu.so -> build/bdist.linux-x86_64/egg
creating stub loader for x2_cpp.cpython-37m-x86_64-linux-gnu.so
byte-compiling build/bdist.linux-x86_64/egg/x2_cpp.py to x2_cpp.cpython-37.pyc
creating build/bdist.linux-x86_64/egg/EGG-INFO
copying x2_cpp.egg-info/PKG-INFO -> build/bdist.linux-x86_64/egg/EGG-INFO
copying x2_cpp.egg-info/SOURCES.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying x2_cpp.egg-info/dependency_links.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying x2_cpp.egg-info/top_level.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
writing build/bdist.linux-x86_64/egg/EGG-INFO/native_libs.txt
zip_safe flag not set; analyzing archive contents...
__pycache__.x2_cpp.cpython-37: module references __file__
creating 'dist/x2_cpp-0.0.0-py3.7-linux-x86_64.egg' and adding 'build/bdist.linux-x86_64/egg' to it
removing 'build/bdist.linux-x86_64/egg' (and everything under it)
Processing x2_cpp-0.0.0-py3.7-linux-x86_64.egg
removing '/home/my_path/anaconda3/envs/PyTorch/lib/python3.7/site-packages/x2_cpp-0.0.0-py3.7-linux-x86_64.egg' (and everything under it)
creating /home/my_path/anaconda3/envs/PyTorch/lib/python3.7/site-packages/x2_cpp-0.0.0-py3.7-linux-x86_64.egg
Extracting x2_cpp-0.0.0-py3.7-linux-x86_64.egg to /home/my_path/anaconda3/envs/PyTorch/lib/python3.7/site-packages
x2-cpp 0.0.0 is already the active version in easy-install.pth
Installed /home/my_path/anaconda3/envs/PyTorch/lib/python3.7/site-packages/x2_cpp-0.0.0-py3.7-linux-x86_64.egg
Processing dependencies for x2-cpp==0.0.0
Finished processing dependencies for x2-cpp==0.0.0
test.py
import torch
import x2_cpp
# Create a random 2x2 tensor, print it, then print what the extension's x2
# function returns for it.
sample = torch.rand((2, 2))
print(sample)
result = x2_cpp.x2(sample)
print(result)
output:
tensor([[0.6516, 0.4836],
[0.8497, 0.2106]])
hi there from 0
hi there from 1
tensor([[0.4246, 0.2338],
[0.7220, 0.0444]])
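Maybe it’s an ugly way to get `at::parallel_for` to run on more than one thread, but it does work: the output above shows the loop body running on threads 0 and 1.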