I’m trying to run train.py from https://github.com/karpathy/nanoGPT on an RTX 4090 (Ada) card. torch.compile(model)
fails with the following error:
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised RuntimeError: Internal Triton PTX codegen error:
ptxas /tmp/fileelTjxJ, line 6; error : PTX .version 7.4 does not support .target sm_89
ptxas fatal : Ptx assembly aborted due to errors
My environment:
Collecting environment information...
PyTorch version: 2.0.0.dev20230119
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.10 (x86_64)
GCC version: (Ubuntu 12.2.0-3ubuntu1) 12.2.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.36
Python version: 3.10.0 (default, Mar 3 2022, 09:58:08) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.19.0-29-generic-x86_64-with-glibc2.36
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA Graphics Device
Nvidia driver version: 520.61.05
cuDNN version: Probably one of the following:
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn.so.8.7.0
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.7.0
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.7.0
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.7.0
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.7.0
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.7.0
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.7.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] torch==2.0.0.dev20230119
[pip3] torchaudio==2.0.0.dev20230119
[pip3] torchvision==0.15.0.dev20230119
[conda] blas 1.0 mkl
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py310h7f8727e_0
[conda] mkl_fft 1.3.1 py310hd6ae3a3_0
[conda] mkl_random 1.2.2 py310h00e6091_0
[conda] numpy 1.23.5 py310hd5efca6_0
[conda] numpy-base 1.23.5 py310h8e6c178_0
[conda] pytorch 2.0.0.dev20230119 py3.10_cuda11.8_cudnn8.5.0_0 pytorch-nightly
[conda] pytorch-cuda 11.8 h8dd9ede_2 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch-nightly
[conda] torchaudio 2.0.0.dev20230119 py310_cu118 pytorch-nightly
[conda] torchtriton 2.0.0+0d7e753227 py310 pytorch-nightly
[conda] torchvision 0.15.0.dev20230119 py310_cu118 pytorch-nightly
ptxas --version returns:
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0
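To double-check which ISA version the generated kernel actually declares, I pulled the `.version` directive out of the PTX text. The header below is a hypothetical reconstruction (the /tmp file is gone by the time the error surfaces), but the extraction itself is straightforward:

```python
import re

def ptx_isa_version(ptx_text: str) -> str:
    """Extract the '.version' directive from PTX assembly text."""
    m = re.search(r"^\.version\s+(\d+\.\d+)", ptx_text, re.MULTILINE)
    return m.group(1) if m else "unknown"

# Hypothetical header resembling what ptxas rejected in /tmp:
sample = """\
//
// Generated by LLVM NVPTX Back-End
//
.version 7.4
.target sm_89
.address_size 64
"""
print(ptx_isa_version(sample))  # -> 7.4
```

So the PTX itself is being emitted as ISA 7.4 with an sm_89 target, which ptxas rightly rejects.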
I’m not sure where PTX ISA 7.4 comes into play here, or what to update / re-compile to enable sm_89 (Ada). If I understand correctly, everything in this environment is built against CUDA 11.8, which ships PTX ISA 7.8 and supports sm_89.
Any recommendations?
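For context on my assumption above, here is a minimal sketch of the first PTX ISA version that supports each recent SM target, as I read the CUDA release notes (illustrative only; please correct me if any entry is wrong):

```python
# Minimal lookup: first PTX ISA version supporting each SM target.
# Values are my reading of the CUDA release notes (illustrative, not exhaustive).
MIN_PTX_ISA_FOR_SM = {
    "sm_80": "7.0",  # Ampere, CUDA 11.0
    "sm_86": "7.1",  # Ampere, CUDA 11.1
    "sm_89": "7.8",  # Ada,    CUDA 11.8
    "sm_90": "7.8",  # Hopper, CUDA 11.8
}

def min_ptx_isa(sm_target: str) -> str:
    return MIN_PTX_ISA_FOR_SM[sm_target]

print(min_ptx_isa("sm_89"))  # -> 7.8, yet the error shows PTX .version 7.4
```

If that table is right, the 7.4 in the error presumably comes from whatever Triton build torchtriton bundles, not from the CUDA 11.8 toolkit — that’s the part I’d like to confirm.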
Full stack trace:
concurrent.futures.process._RemoteTraceback:
"""
Traceback (most recent call last):
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/concurrent/futures/process.py", line 243, in _process_worker
r = call_item.fn(*call_item.args, **call_item.kwargs)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 533, in _worker_compile
kernel.precompile(warm_cache_only_with_cc=cc)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_inductor/triton_ops/autotune.py", line 59, in precompile
self.launchers = [
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_inductor/triton_ops/autotune.py", line 60, in <listcomp>
self._precompile_config(c, warm_cache_only_with_cc)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_inductor/triton_ops/autotune.py", line 73, in _precompile_config
triton.compile(
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/triton/compiler.py", line 1256, in compile
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/triton/compiler.py", line 901, in _compile
name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, module, device, num_warps, num_stages, extern_libs, cc)
RuntimeError: Internal Triton PTX codegen error:
ptxas /tmp/fileSePVLQ, line 6; error : PTX .version 7.4 does not support .target sm_89
ptxas fatal : Ptx assembly aborted due to errors
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 674, in call_user_compiler
compiled_fn = compiler_fn(gm, self.fake_example_inputs())
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 1047, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/__init__.py", line 1264, in __call__
return self.compile_fn(model_, inputs_)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/optimizations/backends.py", line 24, in inner
return fn(gm, example_inputs, **kwargs)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/optimizations/backends.py", line 61, in inductor
return compile_fx(*args, **kwargs)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 411, in compile_fx
return aot_autograd(
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/optimizations/training.py", line 78, in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2453, in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 96, in time_wrapper
r = func(*args, **kwargs)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2150, in create_aot_dispatcher_function
compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1412, in aot_wrapper_dedupe
return compiler_fn(flat_fn, leaf_flat_args, aot_config)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1062, in aot_dispatch_base
compiled_fw = aot_config.fw_compiler(fw_module, flat_args)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 96, in time_wrapper
r = func(*args, **kwargs)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 386, in fw_compiler
return inner_compile(
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 586, in debug_wrapper
compiled_fn = compiler_fn(gm, example_inputs, **kwargs)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_inductor/debug.py", line 224, in inner
return fn(*args, **kwargs)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 153, in compile_fx_inner
compiled_fn = graph.compile_to_fn()
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_inductor/graph.py", line 545, in compile_to_fn
return self.compile_to_module().call
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 96, in time_wrapper
r = func(*args, **kwargs)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_inductor/graph.py", line 534, in compile_to_module
mod = PyCodeCache.load(code)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 504, in load
exec(code, mod.__dict__, mod.__dict__)
File "/tmp/torchinductor_jrahn/3c/c3cmse7l372boit76z5ugnr2v7pxwsof5xkbmih2v5f77zwl2n4e.py", line 1095, in <module>
async_compile.wait(globals())
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 699, in wait
scope[key] = result.result()
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 557, in result
self.future.result()
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/concurrent/futures/_base.py", line 445, in result
return self.__get_result()
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/concurrent/futures/_base.py", line 390, in __get_result
raise self._exception
RuntimeError: Internal Triton PTX codegen error:
ptxas /tmp/fileSePVLQ, line 6; error : PTX .version 7.4 does not support .target sm_89
ptxas fatal : Ptx assembly aborted due to errors
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/jrahn/dev/nanoGPT/train.py", line 223, in <module>
losses = estimate_loss()
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/jrahn/dev/nanoGPT/train.py", line 184, in estimate_loss
logits, loss = model(X, Y)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
return forward_call(*args, **kwargs)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 211, in _fn
return fn(*args, **kwargs)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 332, in catch_errors
return callback(frame, cache_size, hooks)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 480, in _convert_frame
result = inner_convert(frame, cache_size, hooks)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 103, in _fn
return fn(*args, **kwargs)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 96, in time_wrapper
r = func(*args, **kwargs)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 339, in _convert_frame_assert
return _compile(
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 400, in _compile
out_code = transform_code_object(code, transform)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
transformations(instructions, code_options)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 387, in transform
tracer.run()
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1684, in run
super().run()
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 538, in run
and self.step()
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 501, in step
getattr(self, inst.opname)(inst)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1750, in RETURN_VALUE
self.output.compile_subgraph(self)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 551, in compile_subgraph
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 598, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/home/jrahn/miniconda3/envs/pt2/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 679, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised RuntimeError: Internal Triton PTX codegen error:
ptxas /tmp/fileSePVLQ, line 6; error : PTX .version 7.4 does not support .target sm_89
ptxas fatal : Ptx assembly aborted due to errors
Set torch._dynamo.config.verbose=True for more information
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True