PyTorch 2.0 distribution that uses CUDA only if available?

Hey folks,

After upgrading to torch==2.0 yesterday, I found that I am no longer able to run torch programs if the system doesn’t have CUDA.

Here’s my observation on the various distributions:

# PyTorch 2.0 for use with CUDA, CUDA libs get installed as pip deps if unavailable on the system. Wheel size w/o deps: ~620MB
pip3 install torch==2.0
# PyTorch 2.0 with bundled CUDA 11.7. Wheel size w/o deps: 1.8GB
pip3 install torch==2.0+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
# PyTorch 2.0 w/o CUDA support. Wheel size w/o deps: 195MB
pip3 install torch==2.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu

If I now look at what got installed from the first option (site-packages/torch/lib), I see, among other things:

-rwxrwxr-x 1 ubuntu ubuntu 487M Apr 11 08:33 libtorch_cpu.so
-rwxrwxr-x 1 ubuntu ubuntu 627M Apr 11 08:33 libtorch_cuda.so

so my expectation would be that this distribution allows me to use Torch with or without CUDA support.

However, in reality, import torch fails on a non-CUDA system (without the CUDA pip deps installed): ldd libtorch_global_deps.so shows that the global-deps library, which is unconditionally loaded at package import time, is linked against a number of CUDA libraries (libcublas.so, libcurand.so and others), and these fail to load on a non-CUDA system.
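For illustration, the check looked roughly like this (the venv path is from my system; the comments describe what I saw rather than literal output, and the exact missing library named in the error may differ):

$ ldd venv/lib/python3.9/site-packages/torch/lib/libtorch_global_deps.so | grep cuda
# lists libcublas.so.*, libcurand.so.* etc., all marked "not found" on the CUDA-less machine
$ python -c "import torch"
# fails with an OSError about the first missing CUDA shared object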

This is apparently different from the behavior in torch==1.11.0 (the previous version I was using). There, I also see

-rwxrwxr-x 1 ubuntu ubuntu 433M Apr 11 09:51 libtorch_cpu.so
-rwxrwxr-x 1 ubuntu ubuntu 994M Apr 11 09:50 libtorch_cuda.so

in the lib folder of the package, and I can indeed use CUDA on a CUDA system, but libtorch_global_deps.so does not link against any CUDA libraries:

$ ldd venv/lib/python3.9/site-packages/torch/lib/libtorch_global_deps.so | grep cuda
$

Does anyone have any insight into why this change was made? It makes it much, much harder to use a consistent set of dependencies across systems/architectures.

Now, this doesn’t matter if installing the right version of torch is regarded as a responsibility of the system. In our case, however, we use Bazel as our toolchain and thus require some level of hermetic homogeneity between building our production containers (using a base image with CUDA system libs) and, e.g., running basic functional tests in CI (on runners that don’t have a GPU and where installing CUDA is a waste of time and space). Concretely, we don’t want to install over a GB of CUDA libraries as a Python dependency, because shipping them in a base-image layer is more efficient. We also don’t want to install CUDA in the image used for our CI runner, because it will never have GPUs. But we do want to be able to run bazel test on a GPU-enabled, CUDA-enabled Linux system and have Torch use CUDA, without Bazel resolving dependencies differently from CI under the hood.

I guess we could somehow address this at the Bazel level if there’s no other way, but why do we have to now, when with torch 1.11 we didn’t need to and still had CUDA support?

I don’t fully understand this claim. Are you installing the default PyTorch pip wheels hosted on PyPI, which pull in the CUDA 11.7 dependencies via the CUDA pip wheels, and then deleting those dependencies manually afterwards?
And after that you are trying to import torch and wondering why the import fails?

Without deleting libraries manually the CUDA-enabled wheels will still run:

CUDA_VISIBLE_DEVICES="" python -c "import torch; print(torch.cuda.is_available()); print(torch.__version__)"
False
2.0.0+cu117

However, if you want to install CPU-only wheels you could select them from the install matrix.

This is caused by a change in our build process: previously we were statically linking the CUDA math libs into libtorch*, so they were not directly visible as dependencies.
Besides that, the user experience is the same: you can install the CUDA-enabled wheels on a CPU-only system and they will work (the default PyPI wheels). If you want to install the CPU-only wheels, you could select them from the install matrix.

Could you describe what exactly is failing if you don’t manually delete dependencies?

Thanks for your response, you’re absolutely right. What I saw was actually a bug in the tool (based on pex) that we use to create a Bazel lockfile out of a requirements.txt. It operates under the assumption that wheels for different platforms all have the same requirements, just annotated with environment markers, but that’s not the case for torch: the requirements listed in the torch Linux wheel metadata carry Linux/x86_64 environment markers, so it wouldn’t hurt to have them in the Darwin wheel as well, but they’re not there. Either way, I think the tool’s assumption is clearly too strict.
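For reference, those markers are visible directly in the installed metadata. A quick, illustrative way to dump them (importlib.metadata needs Python 3.8+; exact versions and marker formatting depend on the wheel):

python -c "from importlib.metadata import requires; print('\n'.join(requires('torch')))"
# on the Linux wheel this includes the nvidia-* requirements with markers along the lines of
#   nvidia-cublas-cu11 ... ; platform_system == "Linux" and platform_machine == "x86_64"
# whereas the macOS wheel doesn't declare the nvidia-* requirements at all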

So the tool accidentally “deleted” the CUDA deps, because they don’t occur in the Darwin wheel. This accidentally turned out to be the “right” thing for our scenario, as we ship our code in an image that has all required CUDA deps installed system-wide through the base image. But that right thing is clearly a hack given what the PyTorch team intended in terms of distribution flavors.

So I guess with one mystery solved, my questions can basically be condensed to the following:

  1. Is there a pre-compiled torch 2.0.0 package that neither bundles CUDA nor pulls in CUDA as Python dependencies, but instead uses the CUDA libraries installed on the system?
  2. Same as 1, but additionally falling back gracefully to CPU-only mode if the required CUDA libs are not present on the system?
  3. If the answer to either of the above is no, because this is too uncommon a setup, is there an option to create such a package by compiling from source (without modifying it)?

Besides that, the user experience is the same: you can install the CUDA-enabled wheels on a CPU-only system and they will work (the default PyPI wheels). If you want to install the CPU-only wheels, you could select them from the install matrix.

It still seems that there is a gap between the two options (CUDA-enabled wheels with CUDA libs as pip dependencies vs. CPU-only wheels): while both run on a GPU-less system, the former needs an extra GB or so of dependencies installed. And that is not just a formal requirement: the package will indeed not run without these deps, even though on such systems it offers the same functionality as the CPU-only package, which doesn’t need them.

This is caused by a change in our build process: previously we were statically linking the CUDA math libs into libtorch*, so they were not directly visible as dependencies.

Hmm, I see, thanks for the explanation.
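Out of curiosity, I’d expect to be able to see that difference directly with ldd, roughly like this (an illustrative sketch: the venv paths are placeholders and the grep pattern just picks a few of the math libs):

$ # torch 1.11 default wheel: math libs statically linked, so this should match little or nothing
$ ldd venv-1.11/lib/python3.9/site-packages/torch/lib/libtorch_cuda.so | grep -E 'cublas|curand|cusparse'
$ # torch 2.0 default wheel: the same libs should now show up as dynamic dependencies (satisfied by the nvidia-* pip packages)
$ ldd venv-2.0/lib/python3.9/site-packages/torch/lib/libtorch_cuda.so | grep -E 'cublas|curand|cusparse'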

Hello, as a follow-up to the posts above: is it possible to get rid of the cpu related cuda library part (libtorch_cpu.so) if libtorch_cuda.so is present?
Thanks.

Could you describe what “cpu related cuda library part” means?
libtorch_cpu.so ships with needed symbols and you can check it via e.g. readelf.
Some symbols would be:

...
554637: 0000000007bf5650  2432 FUNC    GLOBAL DEFAULT    8 mkl_spblas_lp64_avx512_zcoo0stluc__svout_seq
554638: 000000000ac1a850  1888 FUNC    GLOBAL DEFAULT    8 mkl_blas_avx512_xcscal
554639: 00000000052d53a0  2872 FUNC    GLOBAL DEFAULT    8 torch::jit::tensorexpr::registerize(std::shared_ptr<torch::jit::tensorexpr::Stmt>)
554640: 0000000002a98ee0   208 FUNC    GLOBAL DEFAULT    8 at::cpu::_upsample_nearest_exact1d_symint_out(at::Tensor&, at::Tensor const&, c10::ArrayRef<c10::SymInt>, c10::optional<double>)
554641: 00000000025eb780  1492 FUNC    GLOBAL DEFAULT    8 at::_ops::random_from_out::call(at::Tensor const&, long, c10::optional<long>, c10::optional<at::Generator>, at::Tensor&)
554642: 000000000696cd75   265 FUNC    GLOBAL DEFAULT    8 xnn_run_elu_nc_f32
554643: 0000000016c9a480    40 OBJECT  GLOBAL DEFAULT   24 at::native::_segment_reduce_offsets_backward_stub
554644: 000000000e709200    64 FUNC    GLOBAL DEFAULT    8 mkl_dft_mc3_ippsFFTInitAlloc_C_32f
554645: 0000000002c1db50    16 FUNC    GLOBAL DEFAULT    8 at::compositeexplicitautograd::indices(at::Tensor const&)
554646: 0000000010a61680  1376 FUNC    GLOBAL DEFAULT    8 mkl_dft_def_dft_row_scopy_8
554647: 000000000d4307a0  1344 FUNC    GLOBAL DEFAULT    8 mkl_blas_avx2_ctrsv_unn
554648: 00000000029701e0   526 FUNC    GLOBAL DEFAULT    8 at::_ops::index_add_dimname::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Dimname, at::Tensor const&, at::Tensor const&, c10::Scalar const&)
554649: 0000000002584990  1437 FUNC    GLOBAL DEFAULT    8 at::_ops::new_zeros::call(at::Tensor const&, c10::ArrayRef<c10::SymInt>, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>)
554650: 0000000004ea3090 11652 FUNC    GLOBAL DEFAULT    8 torch::jit::MutationRemover::createSpecialMappedOp(torch::jit::Node*)
554651: 0000000005182980   401 FUNC    GLOBAL DEFAULT    8 torch::jit::tensorexpr::nnc_aten_quantized_mul_scalar_out(long, void**, long*, long*, long*, signed char*, long, long*)
554652: 000000000adead90    16 FUNC    GLOBAL DEFAULT    8 mkl_blas_mc3_cgemm_initialize_buffers
554653: 000000000e907c90  2064 FUNC    GLOBAL DEFAULT    8 mkl_dft_mc_ownsAdd_8u16u
554654: 0000000007fb08f0  1152 FUNC    GLOBAL DEFAULT    8 mkl_spblas_lp64_avx2_scoo1ntlnf__svout_seq
554655: 00000000053f1630    16 FUNC    WEAK   DEFAULT    8 std::_Sp_counted_ptr_inplace<torch::lazy::DeviceData, std::allocator<torch::lazy::DeviceData>, (__gnu_cxx::_Lock_policy)2>::_M_dispose()
554656: 000000000cec00b0  1456 FUNC    GLOBAL DEFAULT    8 mkl_blas_avx512_zherk_kernel_upper_b0
554657: 000000000b4f24d0  3200 FUNC    GLOBAL DEFAULT    8 mkl_vml_kernel_dExpI_Z0LAynn
554658: 00000000169b6d98    24 OBJECT  WEAK   DEFAULT   19 typeinfo for at::native::structured_index_copy_out

which are unrelated to CUDA, as the library name suggests.
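In case it helps, a symbol listing like the one above can be reproduced with readelf plus c++filt for demangling the C++ names (these flags are just one way to do it, not necessarily what was used here):

$ readelf -Ws venv/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so | c++filt | less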

Hi @ptrblck, thank you, I will have a look with readelf. But what I want to achieve is “easy” in the sense that I want to remove all CUDA libraries (because of their size) that I don’t need if I am certain I will be running PyTorch on a GPU.

If you are certain you will be running PyTorch on a GPU, you won’t be able to remove any CUDA libraries. If you are not planning to use a GPU, install the CPU-only binaries.

Thank you, that answers my question.