PyTorch 2.0 is not a "stable" version

Hey, first of all, I am grateful for this library and the fact that it is open source; I know there is a massive amount of hard work behind it. However, I have to say that with the number of rather critical bugs related to torch.compile, this should simply not have been released as a stable version yet. I am currently doing my master's thesis with PyTorch 2.0, and I have spent a large amount of time debugging errors that turned out to be related to the new torch.compile feature. One of the bugs was so bad that it put the whole A100 GPU (of which I am using a MIG partition) into an error mode and repeatedly crashed the training runs of other students and researchers using different partitions of the same GPU.

So, again, I am a big fan, but I really think a project used this widely by so many people, institutions, and companies should be a bit more conservative with its version labeling.

Hi @Timo_D, can you please share which bugs you're referring to? Either scripts or logs would be helpful for debugging further.

Hi, thank you, I am planning to post a GitHub issue with more details at some point for the major bug I mentioned, but here is an excerpt of the exception if you are interested:

    def forward(self, x):
  File "/home/s334978/miniconda3/envs/ma/lib/python3.10/site-packages/torch/_dynamo/", line 209, in _fn
    return fn(*args, **kwargs)
  File "/home/s334978/miniconda3/envs/ma/lib/python3.10/site-packages/torch/_functorch/", line 2819, in forward
    return compiled_fn(full_args)
  File "/home/s334978/miniconda3/envs/ma/lib/python3.10/site-packages/torch/_functorch/", line 1222, in g
    return f(*args)
  File "/home/s334978/miniconda3/envs/ma/lib/python3.10/site-packages/torch/_functorch/", line 2386, in debug_compiled_function
    return compiled_function(*args)
  File "/home/s334978/miniconda3/envs/ma/lib/python3.10/site-packages/torch/_functorch/", line 1898, in runtime_wrapper
    all_outs = call_func_with_args(
  File "/home/s334978/miniconda3/envs/ma/lib/python3.10/site-packages/torch/_functorch/", line 1247, in call_func_with_args
    out = normalize_as_list(f(args))
  File "/home/s334978/miniconda3/envs/ma/lib/python3.10/site-packages/torch/_functorch/", line 1222, in g
    return f(*args)
  File "/home/s334978/miniconda3/envs/ma/lib/python3.10/site-packages/torch/autograd/", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/s334978/miniconda3/envs/ma/lib/python3.10/site-packages/torch/_functorch/", line 2151, in forward
    fw_outs = call_func_with_args(
  File "/home/s334978/miniconda3/envs/ma/lib/python3.10/site-packages/torch/_functorch/", line 1247, in call_func_with_args
    out = normalize_as_list(f(args))
  File "/home/s334978/miniconda3/envs/ma/lib/python3.10/site-packages/torch/_inductor/", line 248, in run
    return model(new_inputs)
  File "/tmp/torchinductor_s334978/vq/", line 16458, in call, buf2403, buf2399, buf2381, buf2404, buf2406, 32, 512, grid=grid(32), stream=stream0)
  File "/home/s334978/miniconda3/envs/ma/lib/python3.10/site-packages/torch/_inductor/triton_ops/", line 190, in run
    result = launcher(
  File "<string>", line 6, in launcher
  File "/home/s334978/miniconda3/envs/ma/lib/python3.10/site-packages/triton/", line 1679, in __getattribute__
  File "/home/s334978/miniconda3/envs/ma/lib/python3.10/site-packages/triton/", line 1669, in _init_handles
    max_shared = cuda_utils.get_device_properties(device)["max_shared_mem"]
RuntimeError: Triton Error [CUDA]: unknown error

This made the GPU show "Err!" when running nvidia-smi, and further training was impossible until the server was rebooted. It happened repeatedly, usually after about 3 to 5 days of continuous hyperparameter optimization runs. I got quite paranoid about memory leaks and took measures to trigger garbage collection between the individual training runs of the optimization. The crashes still happened, so I wrote a bash script to monitor VRAM usage, GPU temperature, and power draw, because I wanted to rule things out. As soon as an error occurred, the script also dumped the ps output of the currently running processes and the output of nvidia-smi. There was nothing noticeable there as far as I could tell. I also created a new conda env and reinstalled torch and all other dependencies, which didn't stop the crashes.
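For reference, the core of that monitoring boiled down to polling nvidia-smi in query mode and logging the values. Here is a rough Python equivalent of the idea (a sketch, not my exact script; the queried field list is illustrative):

```python
import subprocess
from datetime import datetime

# Fields queried from nvidia-smi. When the GPU is in an error state,
# query mode returns non-numeric values, so parsing must tolerate that.
QUERY_FIELDS = "memory.used,temperature.gpu,power.draw"

def parse_sample(line: str) -> dict:
    """Parse one line of `nvidia-smi --query-gpu=... --format=csv,noheader,nounits`."""
    values = [v.strip() for v in line.split(",")]
    sample = {"timestamp": datetime.now().isoformat()}
    for name, value in zip(QUERY_FIELDS.split(","), values):
        try:
            sample[name] = float(value)
        except ValueError:
            sample[name] = None  # value unreadable, e.g. GPU in error state
    return sample

def poll_gpu() -> dict:
    """One monitoring step: query the GPU and parse the result."""
    out = subprocess.run(
        ["nvidia-smi", f"--query-gpu={QUERY_FIELDS}",
         "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    ).stdout.strip()
    return parse_sample(out)
```

Run in a loop (e.g. every few seconds) and appended to a log file, this gives a timeline you can line up against the crash time.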

The other issue that was incredibly frustrating for me is that, in my tests, training with torch.compile could not be made deterministic (Reproducibility — PyTorch 2.0 documentation). I found out about this only after a few weeks of training, so I had to repeat a lot of my training runs (I have the strict requirement that a certain subset of the training runs in my thesis must be reproducible). With compilation there were always differences in the metrics of two runs that should have been identical. Without compiling the model beforehand, two runs with the same configuration were perfectly identical: same loss curves, same metrics, etc. I believe there are already some submitted issues on GitHub about this.
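The reproducibility check itself is simple: run the same configuration twice with the same seed and compare the logged metrics bit-for-bit. Sketched below with Python's stdlib RNG standing in for a training run (in actual PyTorch code the seeding would go through torch.manual_seed and torch.use_deterministic_algorithms(True), per the reproducibility docs; the function names here are made up for illustration):

```python
import random

def fake_training_run(seed: int, steps: int = 5) -> list[float]:
    """Stand-in for a training run: the seed fixes every 'loss' value."""
    rng = random.Random(seed)  # independent generator, no global state
    loss = 1.0
    losses = []
    for _ in range(steps):
        loss *= 0.9 + 0.1 * rng.random()  # pseudo loss curve
        losses.append(loss)
    return losses

def runs_identical(seed: int) -> bool:
    """Two runs with the same seed must match exactly, not approximately."""
    return fake_training_run(seed) == fake_training_run(seed)
```

Eager-mode PyTorch passed this kind of exact-equality check for me; with torch.compile it did not.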

I am sorry for the slightly ranty nature of my post yesterday; I was in a very frustrated state, as I am under a lot of pressure to finish my thesis, and these issues really complicated things for me. In general, while trying to debug my problems and looking through GitHub issues, I was just a bit surprised to find a large number of quite critical-sounding problems related to torch.compile, for example memory leaks.