Hi all,
The compilation functionality recently added in PyTorch 2.x (torch.compile, built on torch._dynamo) looks like great work, and I wanted to try the speedups on my U-Net model. However, I cannot get a compiled/optimized function or module on my machine. At the end of a very deep stack trace, compilation fails and Dynamo falls back to the eager-mode model, so there is no speedup. The error is:
```
[2023-07-29 04:04:26,800] torch._dynamo.convert_frame: [WARNING] File "/home/user/anaconda3/envs/torchnightly/lib/python3.11/site-packages/triton/compiler/compiler.py", line 589, in _init_handles
[2023-07-29 04:04:26,800] torch._dynamo.convert_frame: [WARNING] mod, func, n_regs, n_spills = fn_load_binary(self.metadata["name"], self.asm[bin_path], self.shared, device)
[2023-07-29 04:04:26,800] torch._dynamo.convert_frame: [WARNING] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[2023-07-29 04:04:26,800] torch._dynamo.convert_frame: [WARNING] torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
[2023-07-29 04:04:26,800] torch._dynamo.convert_frame: [WARNING] RuntimeError: Triton Error [CUDA]: device kernel image is invalid
```
I searched for the problem but could not find anything relevant. Since even the very simple script below triggers the error, I suspect a fundamental compatibility or setup problem rather than something specific to my model.
```python
import torch
import torch._dynamo
import logging
from torch._dynamo import config

config.verbose = True
# config.log_level = logging.INFO


class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.norm = torch.nn.InstanceNorm3d(num_features=1)
        self.conv = torch.nn.Conv3d(in_channels=1, out_channels=3, kernel_size=3)
        self.activation = torch.nn.ReLU(inplace=True)
        self.final = torch.nn.Conv3d(in_channels=3, out_channels=3, kernel_size=3)

    def forward(self, x):
        y = self.norm(x)
        y = self.conv(y)
        y = self.activation(y)
        y = self.final(y)
        return y


def main():
    model = Model()
    model = model.float()
    device = torch.device('cuda:0')
    model = model.to(device)
    opt_model = torch.compile(model, backend='inductor')
    # Unbatched 3D input laid out as (C, D, H, W): 1 channel, 64x64x64 volume
    input_shape = (1, 64, 64, 64)
    x = torch.randn(input_shape, dtype=torch.float32, device=device)
    result = opt_model(x)
    print(result.shape)


if __name__ == '__main__':
    main()
```
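To narrow down whether the error comes from Inductor or from Triton itself, one could launch a single hand-written Triton kernel with no torch.compile involved. The sketch below follows the vector-add example from the Triton tutorials; `add_kernel` and `triton_smoke_test` are hypothetical names chosen for illustration. If it raises the same "device kernel image is invalid" error, the problem likely lies in the Triton/driver setup rather than in the model or in Inductor.

```python
# Hypothetical isolation check: compile and launch one plain Triton kernel,
# bypassing torch.compile/Inductor. The kernel follows Triton's vector-add tutorial.
try:
    import torch
    import triton
    import triton.language as tl
except ImportError:  # lets the file load on machines without torch/Triton
    torch = triton = tl = None

if triton is not None:

    @triton.jit
    def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements  # guard the last, partially filled block
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)


def triton_smoke_test(n: int = 98432):
    """Return True if the kernel gives the right result, None if torch/Triton/CUDA are missing.

    Errors raised while loading the kernel binary (such as the
    'device kernel image is invalid' error above) propagate to the caller.
    """
    if triton is None or not torch.cuda.is_available():
        return None
    x = torch.rand(n, device="cuda")
    y = torch.rand(n, device="cuda")
    out = torch.empty_like(x)
    block = 1024
    grid = (triton.cdiv(n, block),)  # one program instance per block of elements
    add_kernel[grid](x, y, out, n, BLOCK_SIZE=block)
    torch.cuda.synchronize()
    return torch.allclose(out, x + y)


if __name__ == "__main__":
    print("Triton kernel OK:", triton_smoke_test())
```

Run as a script, it prints True on success, or None if torch, Triton, or a CUDA device is not available.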
Does anyone have a hint? Could it be the driver, the CUDA setup, or something else? My environment is:
PyTorch version: 2.1.0.dev20230727
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A
OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: version 3.18.4
Libc version: glibc-2.31
Python version: 3.11.4 (main, Jul 5 2023, 13:45:01) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.10.0-8-amd64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.2.152
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: A100-SXM4-40GB
GPU 1: A100-SXM4-40GB
GPU 2: A100-SXM4-40GB
GPU 3: A100-SXM4-40GB
Nvidia driver version: 460.91.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 43 bits physical, 48 bits virtual
CPU(s): 128
On-line CPU(s) list: 0-127
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 2
NUMA node(s): 2
Vendor ID: AuthenticAMD
CPU family: 23
Model: 49
Model name: AMD EPYC 7452 32-Core Processor
Stepping: 0
Frequency boost: enabled
CPU MHz: 2480.056
CPU max MHz: 3364.3550
CPU min MHz: 1500.0000
BogoMIPS: 4700.08
Virtualization: AMD-V
L1d cache: 2 MiB
L1i cache: 2 MiB
L2 cache: 32 MiB
L3 cache: 256 MiB
NUMA node0 CPU(s): 0-31,64-95
NUMA node1 CPU(s): 32-63,96-127
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Full AMD retpoline, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate sme ssbd mba sev ibrs ibpb stibp vmmcall sev_es fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif umip rdpid overflow_recov succor smca
Versions of relevant libraries:
[pip3] numpy==1.24.3
[pip3] torch==2.1.0.dev20230727
[pip3] torchaudio==2.1.0.dev20230727
[pip3] torchvision==0.16.0.dev20230727
[pip3] triton==2.1.0
[conda] blas 1.0 mkl
[conda] brotlipy 0.7.0 py311h9bf148f_1002 pytorch-nightly
[conda] cffi 1.15.1 py311h9bf148f_3 pytorch-nightly
[conda] cryptography 38.0.4 py311h46ebde7_0 pytorch-nightly
[conda] cudatoolkit 11.8.0 h37601d7_11 conda-forge
[conda] filelock 3.9.0 py311_0 pytorch-nightly
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py311h9bf148f_0 pytorch-nightly
[conda] mkl_fft 1.3.1 py311hc796f24_0 pytorch-nightly
[conda] mkl_random 1.2.2 py311hbba84a0_0 pytorch-nightly
[conda] mpmath 1.2.1 py311_0 pytorch-nightly
[conda] numpy 1.24.3 py311hc206e33_0
[conda] numpy-base 1.24.3 py311hfd5febd_0
[conda] pillow 9.3.0 py311h3fd9d12_2 pytorch-nightly
[conda] pysocks 1.7.1 py311_0 pytorch-nightly
[conda] pytorch 2.1.0.dev20230727 py3.11_cuda11.8_cudnn8.7.0_0 pytorch-nightly
[conda] pytorch-cuda 11.8 h7e8668a_5 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch-nightly
[conda] requests 2.28.1 py311_0 pytorch-nightly
[conda] torchaudio 2.1.0.dev20230727 py311_cu118 pytorch-nightly
[conda] torchtriton 2.1.0+9e3e10c5ed py311 pytorch-nightly
[conda] torchvision 0.16.0.dev20230727 py311_cu118 pytorch-nightly
[conda] urllib3 1.26.14 py311_0 pytorch-nightly
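On the driver question: the listing above shows NVIDIA driver 460.91.03 and a system CUDA runtime of 11.2, while PyTorch was built with CUDA 11.8. A minimal sketch for comparing the two directly follows, assuming Linux (driver library `libcuda.so.1`) and calling the CUDA driver API function `cuDriverGetVersion` through ctypes; the helper names are hypothetical. An older driver can still run 11.x builds through CUDA minor-version compatibility, but NVIDIA documents limitations for that mode (notably around PTX JIT compilation), so a mismatch here would be a suspect rather than a confirmed cause.

```python
import ctypes
from typing import Optional, Tuple


def decode_cuda_version(encoded: int) -> Tuple[int, int]:
    """cuDriverGetVersion encodes versions as 1000 * major + 10 * minor (11020 -> (11, 2))."""
    return encoded // 1000, (encoded % 1000) // 10


def driver_cuda_version() -> Optional[Tuple[int, int]]:
    """Highest CUDA version the installed driver supports, or None if libcuda is unavailable."""
    try:
        libcuda = ctypes.CDLL("libcuda.so.1")  # assumption: Linux driver library name
    except OSError:
        return None
    version = ctypes.c_int(0)
    # CUresult cuDriverGetVersion(int *driverVersion); 0 means CUDA_SUCCESS.
    # This call does not require cuInit().
    if libcuda.cuDriverGetVersion(ctypes.byref(version)) != 0:
        return None
    return decode_cuda_version(version.value)


def driver_is_older_than(build: str, driver: Tuple[int, int]) -> bool:
    """True if the driver's CUDA version is below a build version string such as '11.8'."""
    major, minor = (int(part) for part in build.split(".")[:2])
    return driver < (major, minor)


if __name__ == "__main__":
    drv = driver_cuda_version()
    try:
        import torch
        build = torch.version.cuda  # e.g. '11.8'; None for CPU-only builds
    except ImportError:
        build = None
    print("Driver supports CUDA:", drv)
    print("PyTorch built with CUDA:", build)
    if drv is not None and build is not None and driver_is_older_than(build, drv):
        print("The driver is older than the CUDA version PyTorch was built with.")
```

On a 460-series driver one would expect the first line to report (11, 2), which is below the 11.8 that PyTorch was built with.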