torch.compile regularly fails with "Triton Error [CUDA]: device kernel image is invalid"

Hi all,

The compilation functionality recently released in PyTorch 2.x (torch.compile, built on torch._dynamo) looks like great work, and I wanted to try the speedups on my U-Net model. However, I cannot get a compiled/optimized function or module on my machine. At the end of a very deep stack trace, compilation fails and the model falls back to eager mode, with no speedup, because of:

[2023-07-29 04:04:26,800] torch._dynamo.convert_frame: [WARNING]   File "/home/user/anaconda3/envs/torchnightly/lib/python3.11/site-packages/triton/compiler/compiler.py", line 589, in _init_handles
[2023-07-29 04:04:26,800] torch._dynamo.convert_frame: [WARNING]     mod, func, n_regs, n_spills = fn_load_binary(self.metadata["name"], self.asm[bin_path], self.shared, device)
[2023-07-29 04:04:26,800] torch._dynamo.convert_frame: [WARNING]                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[2023-07-29 04:04:26,800] torch._dynamo.convert_frame: [WARNING] torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
[2023-07-29 04:04:26,800] torch._dynamo.convert_frame: [WARNING] RuntimeError: Triton Error [CUDA]: device kernel image is invalid
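The root exception is easy to lose in a long Dynamo trace like the one above, because every line carries the same timestamp/logger/level prefix. Below is a minimal stdlib-only sketch that strips that prefix and keeps only lines shaped like `SomeError: message`. The prefix pattern is inferred from the log lines above, and `root_errors` is a hypothetical helper, not part of PyTorch.

```python
import re

# Prefix seen on every line of the trace above, e.g.
# "[2023-07-29 04:04:26,800] torch._dynamo.convert_frame: [WARNING] "
PREFIX = re.compile(r"^\[[^\]]+\] [\w.]+: \[[A-Z]+\] ?")
# After the prefix, an exception line starts at column 0 with a dotted name
# ending in Error/Exception/Failed, followed by ": message".
EXC = re.compile(r"^([A-Za-z_][\w.]*(?:Error|Exception|Failed)): (.*)$")


def root_errors(log_text: str) -> list[tuple[str, str]]:
    """Return (exception type, message) pairs found in a prefixed log."""
    found = []
    for raw in log_text.splitlines():
        line = PREFIX.sub("", raw)  # drop timestamp, logger name, level
        match = EXC.match(line)     # indented traceback lines never match
        if match:
            found.append((match.group(1), match.group(2)))
    return found


sample = """\
[2023-07-29 04:04:26,800] torch._dynamo.convert_frame: [WARNING] torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
[2023-07-29 04:04:26,800] torch._dynamo.convert_frame: [WARNING] RuntimeError: Triton Error [CUDA]: device kernel image is invalid
"""
for exc_type, message in root_errors(sample):
    print(f"{exc_type}: {message}")
```

The last pair printed is the error Inductor actually hit; the `BackendCompilerFailed` line only records that the backend gave up.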

I tried to research the problem but could not find any relevant information. Since even the very simple script below reproduces the error, could this be a fundamental compatibility or setup problem?

import torch
import torch._dynamo
import logging

from torch._dynamo import config

config.verbose = True
# config.log_level = logging.INFO

class Model(torch.nn.Module):

    def __init__(self) -> None:
        super().__init__()

        self.norm = torch.nn.InstanceNorm3d(num_features=1)
        self.conv = torch.nn.Conv3d(in_channels=1, out_channels=3, kernel_size=3)
        self.activation = torch.nn.ReLU(inplace=True)
        self.final = torch.nn.Conv3d(in_channels=3, out_channels=3, kernel_size=3)

    def forward(self, x):
        y = self.norm(x)
        y = self.conv(y)
        y = self.activation(y)
        y = self.final(y)
        return y


def main():
    model = Model()
    model = model.float()
    device = torch.device('cuda:0')
    model = model.to(device)
    opt_model = torch.compile(model, backend='inductor')

    input_shape = (1, 64, 64, 64)  # unbatched (C, D, H, W) volume with C=1
    x = torch.randn(input_shape, dtype=torch.float32, device=device)

    result = opt_model(x)
    print(result.shape)


if __name__ == '__main__':
    main()

Does anyone have a hint? Could it be the drivers, the CUDA setup, or something else? My environment is:

PyTorch version: 2.1.0.dev20230727
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: version 3.18.4
Libc version: glibc-2.31

Python version: 3.11.4 (main, Jul  5 2023, 13:45:01) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.10.0-8-amd64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.2.152
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: A100-SXM4-40GB
GPU 1: A100-SXM4-40GB
GPU 2: A100-SXM4-40GB
GPU 3: A100-SXM4-40GB

Nvidia driver version: 460.91.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   43 bits physical, 48 bits virtual
CPU(s):                          128
On-line CPU(s) list:             0-127
Thread(s) per core:              2
Core(s) per socket:              32
Socket(s):                       2
NUMA node(s):                    2
Vendor ID:                       AuthenticAMD
CPU family:                      23
Model:                           49
Model name:                      AMD EPYC 7452 32-Core Processor
Stepping:                        0
Frequency boost:                 enabled
CPU MHz:                         2480.056
CPU max MHz:                     3364.3550
CPU min MHz:                     1500.0000
BogoMIPS:                        4700.08
Virtualization:                  AMD-V
L1d cache:                       2 MiB
L1i cache:                       2 MiB
L2 cache:                        32 MiB
L3 cache:                        256 MiB
NUMA node0 CPU(s):               0-31,64-95
NUMA node1 CPU(s):               32-63,96-127
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Full AMD retpoline, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate sme ssbd mba sev ibrs ibpb stibp vmmcall sev_es fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif umip rdpid overflow_recov succor smca

Versions of relevant libraries:
[pip3] numpy==1.24.3
[pip3] torch==2.1.0.dev20230727
[pip3] torchaudio==2.1.0.dev20230727
[pip3] torchvision==0.16.0.dev20230727
[pip3] triton==2.1.0
[conda] blas                      1.0                         mkl  
[conda] brotlipy                  0.7.0           py311h9bf148f_1002    pytorch-nightly
[conda] cffi                      1.15.1          py311h9bf148f_3    pytorch-nightly
[conda] cryptography              38.0.4          py311h46ebde7_0    pytorch-nightly
[conda] cudatoolkit               11.8.0              h37601d7_11    conda-forge
[conda] filelock                  3.9.0                   py311_0    pytorch-nightly
[conda] mkl                       2021.4.0           h06a4308_640  
[conda] mkl-service               2.4.0           py311h9bf148f_0    pytorch-nightly
[conda] mkl_fft                   1.3.1           py311hc796f24_0    pytorch-nightly
[conda] mkl_random                1.2.2           py311hbba84a0_0    pytorch-nightly
[conda] mpmath                    1.2.1                   py311_0    pytorch-nightly
[conda] numpy                     1.24.3          py311hc206e33_0  
[conda] numpy-base                1.24.3          py311hfd5febd_0  
[conda] pillow                    9.3.0           py311h3fd9d12_2    pytorch-nightly
[conda] pysocks                   1.7.1                   py311_0    pytorch-nightly
[conda] pytorch                   2.1.0.dev20230727 py3.11_cuda11.8_cudnn8.7.0_0    pytorch-nightly
[conda] pytorch-cuda              11.8                 h7e8668a_5    pytorch-nightly
[conda] pytorch-mutex             1.0                        cuda    pytorch-nightly
[conda] requests                  2.28.1                  py311_0    pytorch-nightly
[conda] torchaudio                2.1.0.dev20230727     py311_cu118    pytorch-nightly
[conda] torchtriton               2.1.0+9e3e10c5ed           py311    pytorch-nightly
[conda] torchvision               0.16.0.dev20230727     py311_cu118    pytorch-nightly
[conda] urllib3                   1.26.14                 py311_0    pytorch-nightly
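One thing that stands out in this report (an observation, not something confirmed in the thread) is that the NVIDIA driver 460.91.03 dates from the CUDA 11.2 era, while PyTorch was built with CUDA 11.8. Below is a minimal stdlib-only sketch that pulls those fields out of a `collect_env`-style report and compares the driver against NVIDIA's minimum driver versions for CUDA minor-version compatibility. The helper names are hypothetical, and the thresholds are quoted from NVIDIA's compatibility documentation, so re-check them against the current table.

```python
import re

# Minimum Linux driver for CUDA minor-version compatibility, per major CUDA
# release (NVIDIA CUDA compatibility guide; verify before relying on these).
MIN_DRIVER_FOR_MAJOR = {11: (450, 80, 2), 12: (525, 60, 13)}


def parse_report(report: str) -> dict[str, str]:
    """Collect 'Key: value' lines from a collect_env-style report."""
    fields: dict[str, str] = {}
    for line in report.splitlines():
        key, sep, value = line.partition(":")
        if sep and value.strip():
            fields.setdefault(key.strip(), value.strip())  # keep first match
    return fields


def as_version(text: str) -> tuple[int, ...]:
    """Extract every run of digits: '460.91.03' -> (460, 91, 3)."""
    return tuple(int(part) for part in re.findall(r"\d+", text))


def driver_ok_for(report: str, cuda_major: int) -> bool:
    """True if the reported driver meets the minor-compat floor for cuda_major."""
    driver = as_version(parse_report(report)["Nvidia driver version"])
    return driver >= MIN_DRIVER_FOR_MAJOR[cuda_major]


# Three lines copied from the report above.
report = """\
CUDA used to build PyTorch: 11.8
CUDA runtime version: 11.2.152
Nvidia driver version: 460.91.03
"""
print("PyTorch built with CUDA", as_version(parse_report(report)["CUDA used to build PyTorch"]))
print("driver OK for CUDA 11.x:", driver_ok_for(report, 11))  # True
print("driver OK for CUDA 12.x:", driver_ok_for(report, 12))  # False
```

For this report the driver clears the CUDA 11 floor but not the CUDA 12 one, so code produced by a CUDA 11 toolchain can load while code from a CUDA 12 toolchain would not. That would fit eager mode working while Triton-generated kernels fail, under the unverified assumption that the Triton build in use compiled its kernels with a CUDA 12 `ptxas`.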

Hi, have you solved this error?

This issue should be fixed by the pytorch/builder commit 5c814e2 ("Bundle PTXAS into 11.8 wheel") and by pytorch/pytorch PR #119750 ("[RelEng] Define `BUILD_BUNDLE_PTXAS`", by malfet).
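Because the fix ships a CUDA 11.8 `ptxas` with the wheel, it can be useful to check which CUDA release a given `ptxas` belongs to. Below is a minimal stdlib-only sketch. `ptxas_release` and `installed_ptxas_release` are hypothetical helper names, and the parser assumes the usual `Cuda compilation tools, release X.Y, ...` line in `ptxas --version` output.

```python
import re
import shutil
import subprocess
from typing import Optional, Tuple


def ptxas_release(version_output: str) -> Tuple[int, int]:
    """Extract (major, minor) from the 'release X.Y' part of `ptxas --version`."""
    match = re.search(r"release (\d+)\.(\d+)", version_output)
    if match is None:
        raise ValueError("no 'release X.Y' found in ptxas output")
    return int(match.group(1)), int(match.group(2))


def installed_ptxas_release(path: Optional[str] = None) -> Optional[Tuple[int, int]]:
    """Run a ptxas binary (explicit path, else first on PATH); None if absent."""
    exe = path or shutil.which("ptxas")
    if exe is None:
        return None
    result = subprocess.run(
        [exe, "--version"], capture_output=True, text=True, check=True
    )
    return ptxas_release(result.stdout)


# Illustrative output in the usual format (not taken from the thread).
sample = """ptxas: NVIDIA (R) Ptx optimizing assembler
Cuda compilation tools, release 11.8, V11.8.89
"""
print(ptxas_release(sample))  # (11, 8)
```

On a machine with CUDA installed, `installed_ptxas_release()` checks the first `ptxas` on `PATH`; pass the path of the assembler Triton actually uses to check that one instead (its location depends on the Triton build and is not given in the thread). A `ptxas` whose major release is newer than what the 460.91.03 driver supports (CUDA 11.2) would be a plausible source of "device kernel image is invalid", but that is an inference, not something the thread confirms.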