PyTorch and Bazel

For anyone who ends up on this thread looking for ways to build PyTorch C++ extensions with Bazel ~8.0.0, here's a setup that eventually worked in my case.

BUILD file:

load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
load("@rules_cuda//cuda:defs.bzl", "cuda_library")

cuda_library(
    name = "kernel_cc_lib",
    srcs = [ "kernel.cu", ],
    deps = [
        "@rules_cuda//cuda:runtime",
        "@libtorch_archive//:libtorch",
        "@rules_python//python/cc:current_py_cc_headers",
    ],
    copts = ["-D_GLIBCXX_USE_CXX11_ABI=0"],
)

pybind_extension(
    name = "kernel_wrapper",
    srcs = ["kernel.cpp"],
    deps = [":kernel_cc_lib"],
    copts = ["-DTORCH_EXTENSION_NAME= kernel_wrapper", "-D_GLIBCXX_USE_CXX11_ABI=0"],
)

py_library(
    name = "extension_name",
    data = [":kernel_wrapper"],
    imports = ["."],
)
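
For reference, here's a minimal sketch of what kernel.cpp can look like. The function launch_add_one (and its implementation in kernel.cu) is a hypothetical placeholder for your own kernel interface; TORCH_EXTENSION_NAME is supplied by the -D define in the pybind_extension copts above.

#include <torch/extension.h>

// Implemented in kernel.cu (placeholder name): launches the CUDA kernel.
torch::Tensor launch_add_one(torch::Tensor input);

// TORCH_EXTENSION_NAME expands to kernel_wrapper (set via copts), so the
// Python module name matches the shared object built by pybind_extension.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("add_one", &launch_add_one, "Add one to each element (CUDA)");
}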

The MODULE.bazel file looks something like this:

bazel_dep(name = "rules_python", version = "1.0.0")
bazel_dep(name = "pybind11_bazel", version = "2.13.6")

http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
http_archive(
    name = "libtorch_archive",
    strip_prefix = "libtorch",
    type = "zip",
    build_file = "@//:libtorch.BUILD",
    urls = ["https://download.pytorch.org/libtorch/cu124/libtorch-shared-with-deps-2.5.0%2Bcu124.zip"],
    sha256 = "SHA_THAT_WORKS_FOR_YOU",
)

# https://github.com/bazel-contrib/rules_cuda
bazel_dep(name = "rules_cuda", version = "0.2.3")
cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")
cuda.local_toolchain(
    name = "local_cuda",
    toolkit_path = "/usr/local/cuda",
)
use_repo(cuda, "local_cuda")

libtorch.BUILD is something like this:

package(default_visibility = ["//visibility:public"])

cc_library(
    name = "libtorch",
    srcs = glob(["lib/lib*.so*"]),
    hdrs = glob(["include/**/*.h"]) + glob(["include/**/*.cuh"]),
    includes = [
        "include",
        "include/torch/csrc/api/include",
    ],
)

One can access the wrapper from Python like so:

import kernel_wrapper

...

The extension code itself follows the generic PyTorch guidance for custom C++/CUDA extensions.
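
By generic guidance I mean the usual pattern: validate inputs on the C++ side, then launch the kernel on PyTorch's current CUDA stream. A sketch of the matching kernel.cu, again with placeholder names and a trivial add-one kernel just to show the shape:

#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>

// Trivial placeholder kernel: out[i] = in[i] + 1.
__global__ void add_one_kernel(const float* in, float* out, int64_t n) {
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i] + 1.0f;
}

torch::Tensor launch_add_one(torch::Tensor input) {
  TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
  TORCH_CHECK(input.scalar_type() == torch::kFloat, "input must be float32");
  auto in = input.contiguous();
  auto out = torch::empty_like(in);
  const int64_t n = in.numel();
  const int threads = 256;
  const int blocks = static_cast<int>((n + threads - 1) / threads);
  // Launch on PyTorch's current CUDA stream rather than the default stream.
  add_one_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
      in.data_ptr<float>(), out.data_ptr<float>(), n);
  return out;
}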
