For anyone who ends up on this thread looking for ways to build PyTorch C++ extensions with Bazel (~8.0.0), here's a setup that eventually worked in my case.
`BUILD` file:
```python
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
load("@rules_cuda//cuda:defs.bzl", "cuda_library")
load("@rules_python//python:defs.bzl", "py_library")

cuda_library(
    name = "kernel_cc_lib",
    srcs = ["kernel.cu"],
    deps = [
        "@rules_cuda//cuda:runtime",
        "@libtorch_archive//:libtorch",
        "@rules_python//python/cc:current_py_cc_headers",
    ],
    copts = ["-D_GLIBCXX_USE_CXX11_ABI=0"],
)

pybind_extension(
    name = "kernel_wrapper",
    srcs = ["kernel.cpp"],
    deps = [":kernel_cc_lib"],
    # No space after "=": Bazel shell-tokenizes copts, so a space would split the flag in two.
    copts = ["-DTORCH_EXTENSION_NAME=kernel_wrapper", "-D_GLIBCXX_USE_CXX11_ABI=0"],
)

py_library(
    name = "extension_name",
    srcs = [],
    deps = [],
    data = [":kernel_wrapper"],
    imports = ["."],
)
```
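Bazel documents that each `copts` entry is subject to Bourne-shell tokenization, which is why a stray space inside `-DTORCH_EXTENSION_NAME=...` matters: the macro value ends up as a separate compiler argument instead of being attached to the `-D` flag. Python's `shlex.split` follows POSIX shell quoting rules, so it is a reasonable stand-in for seeing what happens to a flag; this is an approximation of Bazel's tokenizer, not a call into Bazel, and `tokenize_copt` is a hypothetical helper name.

```python
import shlex


def tokenize_copt(copt: str) -> list[str]:
    """Approximate how Bazel splits one `copts` string into compiler arguments.

    Bazel applies Bourne-shell tokenization to copts; shlex.split uses POSIX
    shell rules, which is close enough to show the effect of spaces.
    """
    return shlex.split(copt)


if __name__ == "__main__":
    for copt in [
        "-DTORCH_EXTENSION_NAME= kernel_wrapper",  # stray space after "="
        "-DTORCH_EXTENSION_NAME=kernel_wrapper",   # intended single flag
        "-D_GLIBCXX_USE_CXX11_ABI=0",
    ]:
        print(repr(copt), "->", tokenize_copt(copt))
```

With the space, the first entry splits into `-DTORCH_EXTENSION_NAME=` (an empty macro) plus a bare `kernel_wrapper` argument.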
The `MODULE.bazel` file looks something like:
```python
bazel_dep(name = "rules_python", version = "1.0.0")
bazel_dep(name = "pybind11_bazel", version = "2.13.6")

http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

http_archive(
    name = "libtorch_archive",
    strip_prefix = "libtorch",
    type = "zip",
    build_file = "@//:libtorch.BUILD",
    urls = ["https://download.pytorch.org/libtorch/cu124/libtorch-shared-with-deps-2.5.0%2Bcu124.zip"],
    sha256 = "SHA_THAT_WORKS_FOR_YOU",  # placeholder: SHA-256 of the zip you actually download
)

# https://github.com/bazel-contrib/rules_cuda
bazel_dep(name = "rules_cuda", version = "0.2.3")
cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")
cuda.local_toolchain(
    name = "local_cuda",
    toolkit_path = "/usr/local/cuda",
)
use_repo(cuda, "local_cuda")
```
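The `sha256` attribute of `http_archive` has to be the hex SHA-256 of the exact zip listed in `urls`. One way to fill in the placeholder is to download the archive once and hash it locally. A minimal sketch, assuming the zip is already on disk (the function name is hypothetical; on Linux, `sha256sum` on the same file prints the same hex digest):

```python
import hashlib
import sys


def sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA-256 of a file, reading it in chunks so that
    large archives never have to be loaded into memory at once."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


if __name__ == "__main__":
    # e.g. python sha256_of_file.py libtorch-shared-with-deps-2.5.0+cu124.zip
    print(sha256_of_file(sys.argv[1]))
```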
`libtorch.BUILD` is something like:
```python
package(default_visibility = ["//visibility:public"])

cc_library(
    name = "libtorch",
    srcs = glob(["lib/lib*.so*"]),
    deps = [],
    hdrs = glob(["include/**/*.h"]) + glob(["include/**/*.cuh"]),
    includes = [
        "include",
        "include/torch/csrc/api/include",
    ],
)
```
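The `-D_GLIBCXX_USE_CXX11_ABI=0` flag in the `BUILD` file has to agree with how the libtorch binaries were compiled. For 2.5.0, the `libtorch-shared-with-deps` zip referenced in `MODULE.bazel` is the pre-C++11-ABI build; PyTorch also publishes `libtorch-cxx11-abi-shared-with-deps` archives, which would need the flag set to `1`. If you would rather derive the flag from an installed PyTorch wheel, `torch.compiled_with_cxx11_abi()` reports that wheel's ABI. A small sketch (the helper name is hypothetical, and `torch` is treated as optional):

```python
def glibcxx_abi_copt(cxx11_abi: bool) -> str:
    """Build the _GLIBCXX_USE_CXX11_ABI define that must match the libtorch build."""
    return f"-D_GLIBCXX_USE_CXX11_ABI={int(cxx11_abi)}"


if __name__ == "__main__":
    try:
        import torch  # optional: only used to query an installed wheel
    except ImportError:
        print("torch not installed; choose the flag from the libtorch archive variant")
    else:
        print(glibcxx_abi_copt(torch.compiled_with_cxx11_abi()))
```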
One can then import the wrapper from Python like so:
```python
import kernel_wrapper
...
```
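If `import kernel_wrapper` fails, it can help to check which file Python resolves for that name and whether it is a native extension. On Linux, `pybind_extension` produces a shared object named `kernel_wrapper.so`, and Python loads native modules only when the file suffix appears in `importlib.machinery.EXTENSION_SUFFIXES` (a bare `.so` is included on Linux). A hypothetical diagnostic helper, built only on standard-library calls:

```python
import importlib.machinery
import importlib.util
from typing import Optional


def describe_module(name: str) -> Optional[tuple[str, bool]]:
    """Return (origin, is_native_extension) for a top-level module name,
    or None if Python cannot find a file-backed module with that name."""
    spec = importlib.util.find_spec(name)
    if spec is None or spec.origin is None:
        return None
    # A native extension is a file whose suffix is one Python's loader accepts
    # for compiled modules (e.g. ".so" or ".cpython-312-x86_64-linux-gnu.so").
    is_native = spec.origin.endswith(tuple(importlib.machinery.EXTENSION_SUFFIXES))
    return spec.origin, is_native


if __name__ == "__main__":
    print(describe_module("kernel_wrapper"))
```

A `None` result usually means the directory holding `kernel_wrapper.so` is not on `sys.path`, which is what the `imports = ["."]` attribute of the `py_library` is meant to handle when running through Bazel.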
The extension code itself (`kernel.cpp` and `kernel.cu`) follows the generic PyTorch C++/CUDA extension guidance.