How to build a C++/CUDA extension capable with different PyTorch (e.g. `cpu`, `cu102`, `cu116`)

Is it possible to build generic extensions that are suitable for different PyTorch builds (e.g. cpu, cu102, cu113, cu116)?

Originally posted at [BUG] Separate CPU / CUDA wheels · Issue #46 · metaopt/TorchOpt · GitHub.

I’m building C++/CUDA extensions for TorchOpt and then shipping them in wheels.

Wheels built with torch==1.12.0+cu116 are incompatible with torch==1.12.0+cpu (see CI output https://github.com/metaopt/TorchOpt/runs/7553102052 for more details). Each torch build ships with a different set of libraries:

torch==1.12.0+cu116:

$ ls $SITE_PACKAGES/torch/lib
total 3.3G
-rwxr-xr-x 1 root root 1.2M Jul 28 04:01 libc10_cuda.so
-rwxr-xr-x 1 root root 751K Jul 28 04:01 libc10.so
-rwxr-xr-x 1 root root  25K Jul 28 04:01 libcaffe2_nvrtc.so
-rwxr-xr-x 1 root root 335M Jul 28 04:01 libcublasLt.so.11
-rwxr-xr-x 1 root root 150M Jul 28 04:01 libcublas.so.11
-rwxr-xr-x 1 root root 668K Jul 28 04:01 libcudart-45da57e3.so.11.0
-rwxr-xr-x 1 root root 124M Jul 28 04:01 libcudnn_adv_infer.so.8
-rwxr-xr-x 1 root root  92M Jul 28 04:01 libcudnn_adv_train.so.8
-rwxr-xr-x 1 root root 774M Jul 28 04:01 libcudnn_cnn_infer.so.8
-rwxr-xr-x 1 root root  85M Jul 28 04:01 libcudnn_cnn_train.so.8
-rwxr-xr-x 1 root root  86M Jul 28 04:01 libcudnn_ops_infer.so.8
-rwxr-xr-x 1 root root  68M Jul 28 04:01 libcudnn_ops_train.so.8
-rwxr-xr-x 1 root root 155K Jul 28 04:01 libcudnn.so.8
-rwxr-xr-x 1 root root 165K Jul 28 04:01 libgomp-a34b3233.so.1
-rwxr-xr-x 1 root root  44M Jul 28 04:01 libnvrtc-4dd39364.so.11.2
-rwxr-xr-x 1 root root 6.8M Jul 28 04:01 libnvrtc-builtins.so.11.6
-rwxr-xr-x 1 root root  43K Jul 28 04:01 libnvToolsExt-847d78f2.so.1
-rwxr-xr-x 1 root root  44K Jul 28 04:01 libshm.so
-rwxr-xr-x 1 root root 487M Jul 28 04:01 libtorch_cpu.so
-rwxr-xr-x 1 root root 216M Jul 28 04:01 libtorch_cuda_cpp.so
-rwxr-xr-x 1 root root 653M Jul 28 04:01 libtorch_cuda_cu.so
-rwxr-xr-x 1 root root 209M Jul 28 04:01 libtorch_cuda_linalg.so
-rwxr-xr-x 1 root root 163K Jul 28 04:01 libtorch_cuda.so
-rwxr-xr-x 1 root root  21K Jul 28 04:01 libtorch_global_deps.so
-rwxr-xr-x 1 root root  21M Jul 28 04:01 libtorch_python.so
-rwxr-xr-x 1 root root  16K Jul 28 04:01 libtorch.so

torch==1.12.0+cpu:

$ ls $SITE_PACKAGES/torch/lib
total 496M
-rwxr-xr-x 1 root root 269K Jul 28 04:02 libbackend_with_compiler.so
-rwxr-xr-x 1 root root 766K Jul 28 04:02 libc10.so
-rwxr-xr-x 1 root root 165K Jul 28 04:02 libgomp-a34b3233.so.1
-rwxr-xr-x 1 root root 228K Jul 28 04:02 libjitbackend_test.so
-rwxr-xr-x 1 root root  35K Jul 28 04:02 libshm.so
-rwxr-xr-x 1 root root 588K Jul 28 04:02 libtorchbind_test.so
-rwxr-xr-x 1 root root 476M Jul 28 04:02 libtorch_cpu.so
-rwxr-xr-x 1 root root 8.6K Jul 28 04:02 libtorch_global_deps.so
-rwxr-xr-x 1 root root  19M Jul 28 04:02 libtorch_python.so
-rwxr-xr-x 1 root root 7.1K Jul 28 04:02 libtorch.so

In our .cxx and .cu code, we only have one include directive #include <torch/extension.h> and only reference torch::Tensor and AT_DISPATCH_FLOATING_TYPES. But the built shared libraries link against more libraries than expected.

$ ldd /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/adam_op.cpython-37m-x86_64-linux-gnu.so
        linux-vdso.so.1 =>  (0x00007ffcd44ea000)
        libc10.so => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libc10.so (0x00007f75c1243000)
        libc10_cuda.so => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libc10_cuda.so (0x00007f75c109b000)
        libcaffe2_nvrtc.so => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libcaffe2_nvrtc.so (0x00007f75c123c000)
        libshm.so => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libshm.so (0x00007f75c1231000)
        libtorch.so => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libtorch.so (0x00007f75c122c000)
        libtorch_cpu.so => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libtorch_cpu.so (0x00007f75a704d000)
        libtorch_cuda.so => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libtorch_cuda.so (0x00007f75c120c000)
        libtorch_cuda_cpp.so => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libtorch_cuda_cpp.so (0x00007f7599db0000)
        libtorch_cuda_cu.so => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libtorch_cuda_cu.so (0x00007f757237b000)
        libtorch_cuda_linalg.so => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libtorch_cuda_linalg.so (0x00007f7565905000)
        libtorch_global_deps.so => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libtorch_global_deps.so (0x00007f75c1203000)
        libtorch_python.so => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libtorch_python.so (0x00007f7564957000)
        librt.so.1 => /lib64/librt.so.1 (0x00007f756474f000)
        libpthread.so.0 => /lib64/libpthread.so.0 (0x00007f7564533000)
        libdl.so.2 => /lib64/libdl.so.2 (0x00007f756432f000)
        libstdc++.so.6 => /lib64/libstdc++.so.6 (0x00007f7564027000)
        libm.so.6 => /lib64/libm.so.6 (0x00007f7563d25000)
        libgomp.so.1 => /lib64/libgomp.so.1 (0x00007f7563aff000)
        libgcc_s.so.1 => /lib64/libgcc_s.so.1 (0x00007f75638e9000)
        libc.so.6 => /lib64/libc.so.6 (0x00007f756351b000)
        /lib64/ld-linux-x86-64.so.2 (0x00007f75c1199000)
        libgomp-a34b3233.so.1 => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libgomp-a34b3233.so.1 (0x00007f75632f1000)
        libcuda.so.1 => /lib64/libcuda.so.1 (0x00007f7561e96000)
        libnvrtc-4dd39364.so.11.2 => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libnvrtc-4dd39364.so.11.2 (0x00007f755f075000)
        libcudart-45da57e3.so.11.0 => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libcudart-45da57e3.so.11.0 (0x00007f755edcd000)
        libnvToolsExt-847d78f2.so.1 => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libnvToolsExt-847d78f2.so.1 (0x00007f755ebc2000)
        libcudnn.so.8 => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libcudnn.so.8 (0x00007f755e99a000)
        libcublas.so.11 => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libcublas.so.11 (0x00007f755521c000)
        libcublasLt.so.11 => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libcublasLt.so.11 (0x00007f75401b6000)

Then the wheel built with torch==1.12.0+cu116 is incompatible with torch==1.12.0+cpu because the dynamic linker cannot find the CUDA shared libraries:

$ ldd /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/adam_op.cpython-37m-x86_64-linux-gnu.so
        linux-vdso.so.1 =>  (0x00007ffd5997a000)
        libc10.so => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libc10.so (0x00007ff16efde000)
        libc10_cuda.so => not found
        libcaffe2_nvrtc.so => not found
        libshm.so => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libshm.so (0x00007ff16efce000)
        libtorch.so => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libtorch.so (0x00007ff16efcb000)
        libtorch_cpu.so => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libtorch_cpu.so (0x00007ff155b9a000)
        libtorch_cuda.so => not found
        libtorch_cuda_cpp.so => not found
        libtorch_cuda_cu.so => not found
        libtorch_cuda_linalg.so => not found
        libtorch_global_deps.so => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libtorch_global_deps.so (0x00007ff16efc5000)
        libtorch_python.so => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libtorch_python.so (0x00007ff154dc0000)
        librt.so.1 => /lib64/librt.so.1 (0x00007ff154bb8000)
        libpthread.so.0 => /lib64/libpthread.so.0 (0x00007ff15499c000)
        libdl.so.2 => /lib64/libdl.so.2 (0x00007ff154798000)
        libstdc++.so.6 => /lib64/libstdc++.so.6 (0x00007ff154490000)
        libm.so.6 => /lib64/libm.so.6 (0x00007ff15418e000)
        libgomp.so.1 => /lib64/libgomp.so.1 (0x00007ff153f68000)
        libgcc_s.so.1 => /lib64/libgcc_s.so.1 (0x00007ff153d52000)
        libc.so.6 => /lib64/libc.so.6 (0x00007ff153984000)
        /lib64/ld-linux-x86-64.so.2 (0x00007ff16ef39000)
        libgomp-a34b3233.so.1 => /tmp/tmp.3M50Q7bV1d/venv/lib/python3.7/site-packages/torchopt/_lib/../../torch/lib/libgomp-a34b3233.so.1 (0x00007ff15375a000)

Resolved by linking the built shared library against torch/lib/libtorch_python.so only. Previously, we globbed all *.so libs and linked them with our extension.

# Reset any previously collected value before rebuilding the library list.
unset(TORCH_LIBRARIES)

# NOTE(review): TORCH_LIBRARY_PATH is assumed to be set elsewhere (e.g. queried
# from the active torch installation) and to contain the torch/lib directories
# — confirm against the full CMakeLists.txt.
foreach(VAR_PATH ${TORCH_LIBRARY_PATH})
-   file(GLOB TORCH_LIBRARY "${VAR_PATH}/*.so")  # old: glob EVERY torch shared library
-   list(APPEND TORCH_LIBRARIES "${TORCH_LIBRARY}")  # old: link against all of them (pulls in CUDA-only libs)
+   list(APPEND TORCH_LIBRARIES "${VAR_PATH}/libtorch_python.so")  # new: link only libtorch_python.so, present in every torch build
endforeach()

# Link privately: the torch libraries are an implementation detail of the
# extension module and must not propagate to consumers.
target_link_libraries(
    <target> PRIVATE
    ${TORCH_LIBRARIES}
)