Torch2.0 torch.compile(model) error:

The recently released PyTorch 2.0 looks great, so I wanted to give it a try:

import torch
import torchvision

m = torchvision.models.resnet50().cuda()
mm = torch.compile(m)

data = torch.rand(1, 3, 224, 224).cuda()
o = mm(data)

but an error occurred:

/usr/local/anaconda3/lib/python3.7/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: libtorch_cuda_cu.so: cannot open shared object file: No such file or directory
  warn(f"Failed to load image Python extension: {e}")
/usr/local/anaconda3/lib/python3.7/site-packages/torch/_dynamo/eval_frame.py:367: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled.Consider setting `torch.set_float32_matmul_precision('high')`
  "TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled."
Traceback (most recent call last):
  File "test.py", line 9, in <module>
    o = mm(data)
  File "/usr/local/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1482, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/anaconda3/lib/python3.7/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/usr/local/anaconda3/lib/python3.7/site-packages/torch/_dynamo/eval_frame.py", line 211, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/anaconda3/lib/python3.7/site-packages/torchvision/models/resnet.py", line 284, in forward
    def forward(self, x: Tensor) -> Tensor:
  File "/usr/local/anaconda3/lib/python3.7/site-packages/torch/_dynamo/eval_frame.py", line 211, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/anaconda3/lib/python3.7/site-packages/torch/_functorch/aot_autograd.py", line 2343, in forward
    return compiled_fn(full_args)
  File "/usr/local/anaconda3/lib/python3.7/site-packages/torch/_functorch/aot_autograd.py", line 887, in g
    return f(*args)
  File "/usr/local/anaconda3/lib/python3.7/site-packages/torch/_functorch/aot_autograd.py", line 1906, in debug_compiled_function
    return compiled_function(*args)
  File "/usr/local/anaconda3/lib/python3.7/site-packages/torch/_functorch/aot_autograd.py", line 1718, in compiled_function
    all_outs = CompiledFunction.apply(*args_with_synthetic_bases)
  File "/usr/local/anaconda3/lib/python3.7/site-packages/torch/autograd/function.py", line 419, in apply
    return super().apply(*args, **kwargs)
  File "/usr/local/anaconda3/lib/python3.7/site-packages/torch/_functorch/aot_autograd.py", line 1584, in forward
    disable_amp=disable_amp,
  File "/usr/local/anaconda3/lib/python3.7/site-packages/torch/_functorch/aot_autograd.py", line 912, in call_func_with_args
    out = normalize_as_list(f(args))
  File "/usr/local/anaconda3/lib/python3.7/site-packages/torch/_inductor/compile_fx.py", line 199, in run
    return model(new_inputs)
  File "/tmp/torchinductor_root/4r/c4rjyj24xhy4dcta2noqlpecsq3ewxmrtw543abb7vzx47ah2yj5.py", line 2313, in call
    triton_fused_convolution_mean_var_var_1_2.run(buf0, buf2, buf3, buf5, buf7, 128, 6272, grid=grid(128), stream=stream0)
  File "/usr/local/anaconda3/lib/python3.7/site-packages/torch/_inductor/triton_ops/autotune.py", line 180, in run
    stream=stream,
  File "<string>", line 6, in launcher
RuntimeError: Triton Error [CUDA]: invalid argument
*** Error in `python': munmap_chunk(): invalid pointer: 0x00007f72b10c03e9 ***

My environment is:

cuda: 11.7
python: 3.9
gpu: a100
cuda driver version: 470.57.02

I set up the environment following the documentation: https://pytorch.org/get-started/pytorch-2.0/#requirements


I cannot reproduce the issue using the latest PyTorch nightly release on an A100 with a local CUDA 11.7 toolkit. Adding a print(o.sum()) yields:

python tmp.py 
tensor(-9.5465, device='cuda:0', grad_fn=<SumBackward0>)

Env:

python -m torch.utils.collect_env
Collecting environment information...
PyTorch version: 2.0.0.dev20221214+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.4 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.25.0
Libc version: glibc-2.31

Python version: 3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:04:10)  [GCC 10.3.0] (64-bit runtime)
Python platform: Linux-5.4.0-80-generic-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
GPU 2: NVIDIA A100-SXM4-40GB
GPU 3: NVIDIA A100-SXM4-40GB
GPU 4: NVIDIA A100-SXM4-40GB
GPU 5: NVIDIA A100-SXM4-40GB
GPU 6: NVIDIA A100-SXM4-40GB
GPU 7: NVIDIA A100-SXM4-40GB
...

Thanks for your quick reply. This is my environment:

Collecting environment information...
PyTorch version: 2.0.0.dev20221214+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: CentOS Linux 7 (Core) (x86_64)
GCC version: (GCC) 7.5.0
Clang version: 3.4.2 (tags/RELEASE_34/dot2-final)
CMake version: version 3.25.0
Libc version: glibc-2.17

Python version: 3.9.15 | packaged by conda-forge | (main, Nov 22 2022, 15:55:03)  [GCC 10.4.0] (64-bit runtime)
Python platform: Linux-3.10.0-1127.19.1.el7.x86_64-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 11.7.64
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
GPU 2: NVIDIA A100-SXM4-40GB
GPU 3: NVIDIA A100-SXM4-40GB
GPU 4: NVIDIA A100-SXM4-40GB
GPU 5: NVIDIA A100-SXM4-40GB
GPU 6: NVIDIA A100-SXM4-40GB
GPU 7: NVIDIA A100-SXM4-40GB

Nvidia driver version: 470.57.02
cuDNN version: Probably one of the following:
/usr/lib64/libcudnn.so.8.0.5
/usr/lib64/libcudnn_adv_infer.so.8.0.5
/usr/lib64/libcudnn_adv_train.so.8.0.5
/usr/lib64/libcudnn_cnn_infer.so.8.0.5
/usr/lib64/libcudnn_cnn_train.so.8.0.5
/usr/lib64/libcudnn_ops_infer.so.8.0.5
/usr/lib64/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.24.0rc2
[pip3] torch==2.0.0.dev20221214+cu117
[pip3] torchtriton==2.0.0+0d7e753227
[pip3] torchvision==0.14.0
[conda] numpy                     1.24.0rc2                pypi_0    pypi
[conda] torch                     2.0.0.dev20221214+cu117          pypi_0    pypi
[conda] torchtriton               2.0.0+0d7e753227          pypi_0    pypi
[conda] torchvision               0.14.0                   pypi_0    pypi

I don’t know how to fix the bug. By the way, I also tried the PyTorch nightly container:

docker pull ghcr.io/pytorch/pytorch-nightly

But it gives the same error.

Sorry to bother you, but I still haven’t succeeded in running a PyTorch 2.0 program, either by manually installing torch 2.0 or by using a container. Could you help me out? I would be grateful.

Which driver did you install, which solved the issue?
Sorry for no updates, but I wasn’t able to reproduce the issue using different setups.
Based on your initial post you were using 470.57.02 which is also installed on my A100 node.

That’s strange. I read that the driver version required by CUDA 11.7 should be at least 515, so I thought it was a driver problem and am currently upgrading the driver. If you are also using the 470 driver, then I don’t know how to fix this problem.

I thought you’d already updated the driver and fixed the problem, or was this just an idea?
The “native” driver for CUDA 11.7 would be 515, but older ones still support “minor-version compatibility” and should thus not break the use case.
I was wondering whether a mismatched ptxas version installed on your system might be causing the issue. Do you also use a locally installed CUDA 11.7 toolkit, an older one, or multiple installs? I.e. what does ptxas --version return?
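The minor-version compatibility point can be sketched as a simple version comparison. This is only an illustration: the function names are made up for this example, 515 is the "native" driver branch mentioned in this thread, and the 450.80.02 floor is an assumption taken from NVIDIA's CUDA 11.x compatibility notes for Linux, so please check it against the current documentation.

```python
# Assumed thresholds (not read from any API):
#   515       -> "native" driver branch for CUDA 11.7, as stated in this thread
#   450.80.02 -> assumed Linux minimum for CUDA 11.x minor-version compatibility
NATIVE_MIN = (515,)
MINOR_COMPAT_MIN = (450, 80, 2)


def parse_driver(version: str) -> tuple:
    """Turn a driver string such as '470.57.02' into a tuple of ints."""
    return tuple(int(part) for part in version.strip().split("."))


def classify_driver(version: str) -> str:
    """Classify a driver version against the assumed CUDA 11.7 thresholds."""
    parsed = parse_driver(version)
    if parsed >= NATIVE_MIN:
        return "native"
    if parsed >= MINOR_COMPAT_MIN:
        return "minor-version-compatible"
    return "too-old"


if __name__ == "__main__":
    # Driver version reported in this thread.
    print(classify_driver("470.57.02"))
```

Under these assumptions the 470.57.02 driver falls into the minor-version-compatible range, which matches the observation that the same driver works on another A100 node.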

That was just an idea; if it doesn’t work I will edit my reply. Below is the output of ptxas --version:

# ptxas --version
ptxas: NVIDIA (R) Ptx optimizing assembler
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Tue_May__3_18:49:03_PDT_2022
Cuda compilation tools, release 11.7, V11.7.64
Build cuda_11.7.r11.7/compiler.31294372_0
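For anyone who wants to run the same check from a script, here is a minimal sketch that extracts the toolkit version from output like the above. The helper names are hypothetical, and the regular expression only assumes the `V11.7.64`-style token shown in this output; other ptxas builds may format it differently.

```python
import re
import shutil
import subprocess


def parse_ptxas_release(output: str):
    """Return (major, minor, patch) from `ptxas --version` text, or None."""
    match = re.search(r"V(\d+)\.(\d+)\.(\d+)", output)
    if match is None:
        return None
    return tuple(int(group) for group in match.groups())


def local_ptxas_version():
    """Run the ptxas found on PATH (if any) and parse its version."""
    if shutil.which("ptxas") is None:
        return None  # no ptxas on PATH
    result = subprocess.run(
        ["ptxas", "--version"], capture_output=True, text=True, check=False
    )
    return parse_ptxas_release(result.stdout)


if __name__ == "__main__":
    # Compare the result by hand with torch.version.cuda,
    # the CUDA version PyTorch was built against.
    print(local_ptxas_version())
```

If several toolkits are installed, the one that `shutil.which("ptxas")` resolves to is the first on PATH, which may not be the one you expect.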

The attempts I have made are:

  1. Using the officially provided Docker image: docker pull ghcr.io/pytorch/pytorch-nightly. The image lacks many packages, and after installing them one by one it still does not run normally.

  2. Building an image from nvcr.io/nvidia/pytorch:22.05-py3 and installing torch 2.0 via: pip3 install numpy --pre torch --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu117

  3. Installing PyTorch in my older CUDA 11.2 Docker image, then reinstalling CUDA 11.7 and torch 2.0.

None of the above attempts worked.

By the way, attempts 2 and 3 pass the python verify_dynamo.py verification:

# python verify_dynamo.py
Python version: 3.8.13
`torch` version: 2.0.0.dev20221220+cu117
CUDA version: 11.7

/opt/conda/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py:372: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled.Consider setting `torch.set_float32_matmul_precision('high')`
  warnings.warn(
/opt/conda/lib/python3.8/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: /opt/conda/lib/python3.8/site-packages/torchvision/image.so: undefined symbol: _ZN5torch3jit17parseSchemaOrNameERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE
  warn(f"Failed to load image Python extension: {e}")
/opt/conda/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py:372: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled.Consider setting `torch.set_float32_matmul_precision('high')`
  warnings.warn(
/opt/conda/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py:372: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled.Consider setting `torch.set_float32_matmul_precision('high')`
  warnings.warn(
All required checks passed

If the driver version doesn’t matter, then I don’t know what the problem is.

OK, the CUDA driver version doesn’t matter; I deleted my previous comment.

Finally, I found that the cause is a mismatch between torchvision and torch 2.0. My old torchvision was 0.11. It works after updating torchvision to 0.15. Now everything is going well.
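A quick way to catch this kind of mismatch early is to compare the two version strings before compiling anything. The sketch below is only an illustration: the function names and the table are written for this example and cover just the release pairings from 0.11 to 0.15, so anything outside the table is reported as unknown rather than guessed.

```python
import re

# torchvision (major, minor) -> torch (major, minor) it was released for.
# Only a short, hand-written table; extend it for other releases.
KNOWN_PAIRS = {
    (0, 11): (1, 10),
    (0, 12): (1, 11),
    (0, 13): (1, 12),
    (0, 14): (1, 13),
    (0, 15): (2, 0),
}


def major_minor(version: str) -> tuple:
    """Extract (major, minor) from strings like '2.0.0.dev20221214+cu117'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def versions_match(torch_version: str, torchvision_version: str):
    """Return True/False for known pairings, None if torchvision is not in the table."""
    expected = KNOWN_PAIRS.get(major_minor(torchvision_version))
    if expected is None:
        return None
    return expected == major_minor(torch_version)


if __name__ == "__main__":
    # The combination reported in this thread: torch 2.0 nightly with torchvision 0.11.
    print(versions_match("2.0.0.dev20221214+cu117", "0.11.0"))
```

In a real environment the two arguments would be `torch.__version__` and `torchvision.__version__`. Note that pip normally enforces this pairing itself, so a mismatch usually means one of the packages was installed with dependency checks bypassed or left over from an older image.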

One complaint: the official torch 2.0 image does not run out of the box, which is really painful.

Thanks for the update. That’s interesting to hear, as I cannot even install this older torchvision==0.11.0 version when the nightly PyTorch binary is installed:

$ pip install torchvision==0.11.0
Collecting torchvision==0.11.0
  Downloading torchvision-0.11.0-cp38-cp38-manylinux1_x86_64.whl (23.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 23.3/23.3 MB 58.4 MB/s eta 0:00:00
Requirement already satisfied: numpy in ./miniforge3/envs/tmp/lib/python3.8/site-packages (from torchvision==0.11.0) (1.23.5)
ERROR: Could not find a version that satisfies the requirement torch==1.10.0+cu102 (from torchvision) (from versions: 1.4.0, 1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2, 1.11.0, 1.12.0, 1.12.1, 1.13.0, 1.13.1)
ERROR: No matching distribution found for torch==1.10.0+cu102
$ pip list | grep torch
torch                    2.0.0.dev20221222+cu117
torchtriton              2.0.0+0d7e753227

I manually downloaded the wheel and installed it via --no-deps (which I would not recommend) to check whether this would reproduce the issue, and indeed I see a segfault now.


Yes, at first I thought that torchvision was just a modeling library and would not affect the use of torch.compile, but it did cause a segmentation fault in my environment.