"cuda.h" missing during torch.compile in a new environment

I want to compare the newly released ModernBERT model to RoBERTa using Hugging Face's transformers library. Since ModernBERT is so new, the model card asks you to install transformers from git. For this reason, I am creating a new conda environment on our cluster. Below is the install script I've written:

conda create --name nlp2 -c conda-forge -c nvidia python=3.12 \
    einops pip scipy numpy jupyter tqdm nltk numba pandas bitsandbytes \
    trl peft accelerate xformers ipython ffmpeg black matplotlib spacy \
    transformers scikit-learn scikit-image nvidia::cuda-nvcc triton \
    pytorch torchvision torchaudio pytorch-gpu
conda activate nlp2
python -c "import torch; print('testing torch: ',torch.cuda.is_available())"
pip install git+https://github.com/huggingface/transformers.git
pip install flash_attn
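
Before running any models, a quick sanity check of the new environment can save time: verify whether cuda.h is actually visible anywhere a builder might look. The snippet below is just a hypothetical helper, not part of the install; the candidate directories are examples of common locations and may not match every cluster.

```python
import os

def find_header(name, search_dirs):
    """Return the paths under `search_dirs` where the header `name` exists."""
    hits = []
    for d in search_dirs:
        path = os.path.join(d, name)
        if os.path.isfile(path):
            hits.append(path)
    return hits

# Candidate include directories: the active conda env plus a typical
# system-wide CUDA install (paths are illustrative, adjust for your cluster).
prefix = os.environ.get("CONDA_PREFIX", "")
candidates = [
    os.path.join(prefix, "include"),
    os.path.join(prefix, "targets", "x86_64-linux", "include"),
    "/usr/local/cuda/include",
]
print(find_header("cuda.h", candidates))
```

If this prints an empty list, any compile step that does `#include "cuda.h"` without extra `-I` flags is going to fail.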

Using this environment, I then run:

import torch
import transformers
import flash_attn
device = torch.device("cuda")
pipe = transformers.pipeline("fill-mask", model="roberta-base",torch_dtype=torch.bfloat16,device=device)
pipe2 = transformers.pipeline("fill-mask", model="answerdotai/ModernBERT-base",torch_dtype=torch.bfloat16,device=device)

print(pipe(" Hello I am a "+pipe.tokenizer.mask_token+" model."))
print(pipe2(" Hello I am a "+pipe2.tokenizer.mask_token+" model."))

This fails on the last line, during inference with the ModernBERT pipeline: ModernBERT invokes torch.compile internally, and the resulting Triton build step cannot find cuda.h. There is a similar topic here, but the suggested solution there does not help in my case: there is no "cuda.h" file in the corresponding include directory (for me, in ~/miniconda3/envs/nlp2/include/):

tmp/251425/tmpzndjuzee/main.c:1:10: fatal error: cuda.h: No such file or directory
    1 | #include "cuda.h"
      |          ^~~~~~~~
compilation terminated.
Traceback (most recent call last):
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/__init__.py", line 2234, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1521, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_dynamo/backends/common.py", line 72, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1071, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1056, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 759, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
                               ^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 179, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1350, in fw_compiler_base
    return _fw_compiler_base(model, example_inputs, is_inference)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1421, in _fw_compiler_base
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 475, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 661, in _compile_fx_inner
    compiled_graph = FxGraphCache.load(
                     ^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_inductor/codecache.py", line 1324, in load
    compiled_graph = FxGraphCache._lookup_graph(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_inductor/codecache.py", line 1098, in _lookup_graph
    graph.current_callable = PyCodeCache.load_by_key_path(
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_inductor/codecache.py", line 2876, in load_by_key_path
    mod = _reload_python_module(key, path)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/251425/torchinductor_korchins/ch/cchdfc6hl2235kohy2kpamiocwugicwxzyoowz547ze6ki6lpvef.py", line 48, in <module>
    triton_red_fused_embedding_native_layer_norm_0 = async_compile.triton('triton_', '''
                                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_inductor/async_compile.py", line 203, in triton
    kernel.precompile()
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 244, in precompile
    compiled_binary, launcher = self._precompile_config(
                                ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 452, in _precompile_config
    binary._init_handles()
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/triton/compiler/compiler.py", line 368, in _init_handles
    device = driver.active.get_current_device()
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/triton/runtime/driver.py", line 23, in __getattr__
    self._initialize_obj()
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/triton/runtime/driver.py", line 20, in _initialize_obj
    self._obj = self._init_fn()
                ^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/triton/runtime/driver.py", line 9, in _create_driver
    return actives[0]()
           ^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/triton/backends/nvidia/driver.py", line 371, in __init__
    self.utils = CudaUtils()  # TODO: make static
                 ^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/triton/backends/nvidia/driver.py", line 80, in __init__
    mod = compile_module_from_src(Path(os.path.join(dirname, "driver.c")).read_text(), "cuda_utils")
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/triton/backends/nvidia/driver.py", line 57, in compile_module_from_src
    so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/triton/runtime/build.py", line 48, in _build
    ret = subprocess.check_call(cc_cmd)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/subprocess.py", line 413, in check_call
    raise CalledProcessError(retcode, cmd)
subprocess.CalledProcessError: Command '['/home/korchins/miniconda3/envs/nlp2/bin/x86_64-conda-linux-gnu-cc', '/tmp/251425/tmpzndjuzee/main.c', '-O3', '-shared', '-fPIC', '-o', '/tmp/251425/tmpzndjuzee/cuda_utils.cpython-312-x86_64-linux-gnu.so', '-lcuda', '-L/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/triton/backends/nvidia/lib', '-L/lib64', '-I/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/triton/backends/nvidia/include', '-I/tmp/251425/tmpzndjuzee', '-I/home/korchins/miniconda3/envs/nlp2/include/python3.12']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/transformers/pipelines/fill_mask.py", line 270, in __call__
    outputs = super().__call__(inputs, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/transformers/pipelines/base.py", line 1301, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/transformers/pipelines/base.py", line 1308, in run_single
    model_outputs = self.forward(model_inputs, **forward_params)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/transformers/pipelines/base.py", line 1208, in forward
    model_outputs = self._forward(model_inputs, **forward_params)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/transformers/pipelines/fill_mask.py", line 127, in _forward
    model_outputs = self.model(**model_inputs)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 1059, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 895, in forward
    hidden_states = self.embeddings(input_ids)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 210, in forward
    self.compiled_embeddings(input_ids)
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1064, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
    return _compile(
           ^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
    tracer.run()
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
    super().run()
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2987, in RETURN_VALUE
    self._return(inst)
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2972, in _return
    self.output.compile_subgraph(
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1117, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1369, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1416, in call_user_compiler
    return self._call_user_compiler(gm)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1465, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
CalledProcessError: Command '['/home/korchins/miniconda3/envs/nlp2/bin/x86_64-conda-linux-gnu-cc', '/tmp/251425/tmpzndjuzee/main.c', '-O3', '-shared', '-fPIC', '-o', '/tmp/251425/tmpzndjuzee/cuda_utils.cpython-312-x86_64-linux-gnu.so', '-lcuda', '-L/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/triton/backends/nvidia/lib', '-L/lib64', '-I/home/korchins/miniconda3/envs/nlp2/lib/python3.12/site-packages/triton/backends/nvidia/include', '-I/tmp/251425/tmpzndjuzee', '-I/home/korchins/miniconda3/envs/nlp2/include/python3.12']' returned non-zero exit status 1.

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
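
The most informative part of the error is the cc command inside the CalledProcessError: it lists the -I include directories the compiler was given, and if none of them contains cuda.h the build has to fail. A small diagnostic sketch (the command string below is a shortened example; in practice you would paste the full command from your own error):

```python
import os
import shlex

def missing_header_dirs(cc_cmd, header="cuda.h"):
    """Split a cc command line, extract its -I include directories,
    and report for each one whether it actually contains `header`."""
    include_dirs = [
        tok[2:] for tok in shlex.split(cc_cmd) if tok.startswith("-I")
    ]
    return {d: os.path.isfile(os.path.join(d, header)) for d in include_dirs}

# Shortened example command, modeled on the error above:
cmd = ("cc main.c -O3 -shared -fPIC -o cuda_utils.so -lcuda "
       "-I/tmp/tmpdir -I/home/user/miniconda3/envs/nlp2/include/python3.12")
print(missing_header_dirs(cmd))
```

In my case every listed directory (including triton's own backends/nvidia/include) came back False, which is why the compile terminated.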

Sorry for the noise, but in case anyone else stumbles into this problem, the solution (for me) was simple:
conda install conda-forge::torchtriton
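
My understanding (hedged, not verified in the Triton sources) is that the Triton build pulled in by the original install did not bundle the CUDA headers, while conda-forge's torchtriton ships them under triton/backends/nvidia/include, which is exactly the `-I` path the failing cc command used. A quick way to check after installing, using a hypothetical helper (the site-packages path is illustrative):

```python
import os

def triton_cuda_header(site_packages):
    """Return the expected cuda.h path inside triton's NVIDIA backend.
    Layout assumed from the -I path shown in the error message above."""
    return os.path.join(
        site_packages, "triton", "backends", "nvidia", "include", "cuda.h"
    )

# Example usage; substitute your environment's real site-packages path.
path = triton_cuda_header(
    "/home/user/miniconda3/envs/nlp2/lib/python3.12/site-packages"
)
print(path, os.path.isfile(path))
```

After the torchtriton install, that file existed for me and the ModernBERT pipeline ran without the compile error.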