Hi,
Using PyTorch 2.4.0 on Linux (CPU only), my model runs fine without torch.compile. With torch.compile, I get the following error:
Traceback (most recent call last):
File "/home/laurent/project/main.py", line 279, in <module>
main(args)
File "/home/laurent/project/main.py", line 63, in main
train_fct(args=args,
File "/home/laurent/project/train.py", line 299, in train_nn_model
train_loss, val_loss, test_loss = train(args=args, trial=trial,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/project/train.py", line 469, in train
loss, batch_size = forward_backward_pass_model(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/project/train.py", line 880, in forward_backward_pass_model
predictions = model(inputs)
^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
result = self._inner_convert(
^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
return _compile(
^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_utils_internal.py", line 84, in wrapper_function
return StrobelightCompileTimeProfiler.profile_compile_time(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
transformations(instructions, code_options)
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 582, in transform
tracer.run()
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2451, in run
super().run()
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
while self.step():
^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2642, in RETURN_VALUE
self._return(inst)
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2627, in _return
self.output.compile_subgraph(
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1098, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1318, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1409, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1390, in call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/__init__.py", line 1951, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1505, in compile_fx
return aot_autograd(
^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 69, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 954, in aot_module_simplified
compiled_fn, _ = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 687, in create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 461, in aot_dispatch_autograd
compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1410, in fw_compiler_base
return inner_compile(
^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/repro/after_aot.py", line 84, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_inductor/debug.py", line 304, in inner
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 527, in compile_fx_inner
compiled_graph = fx_codegen_and_compile(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 831, in fx_codegen_and_compile
compiled_fn = graph.compile_to_fn()
^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1749, in compile_to_fn
return self.compile_to_module().call
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1699, in compile_to_module
mod = PyCodeCache.load_by_key_path(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 3062, in load_by_key_path
mod = _reload_python_module(key, path)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
exec(code, mod.__dict__, mod.__dict__)
File "/yatmp/scr1/torchinductor_laurent/6z/c6zltwwwixsx6dshpxgxhzat44qrvx7dkcxxniu33cspr2inhgqs.py", line 499, in <module>
async_compile.wait(globals())
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_inductor/async_compile.py", line 247, in wait
scope[key] = result.result()
^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 3433, in result
return self.result_fn()
^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 2654, in future
result = get_result()
^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 2478, in load_fn
future.result()
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/concurrent/futures/_base.py", line 456, in result
return self.__get_result()
^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
raise self._exception
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/concurrent/futures/thread.py", line 58, in run
result = self.fn(*self.args, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 2504, in _worker_compile_cpp
compile_file(input_path, output_path, shlex.split(cmd))
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 2360, in compile_file
raise exc.CppCompileError(cmd, output) from e
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
CppCompileError: C++ compile error
Command:
g++ /yatmp/scr1/torchinductor_laurent/m4/cm4llt7tjnefqnqsjkj2tlli2ips4twifait3zzvm3ed2u6gy5it.cpp -shared -fPIC -Wall -std=c++17 -Wno-unused-variable -Wno-unknown-pragmas -D_GLIBCXX_USE_CXX11_ABI=0 -I/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/include -I/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -I/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/include/TH -I/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/include/THC -I/home/laurent/python3.11_pytorch2.4/include/python3.11 -L/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/lib -L/home/laurent/python3.11_pytorch2.4/lib -L/home/laurent/python3.11_pytorch2.4/lib/python3.11/site-packages/torch/lib -ltorch -ltorch_cpu -lgomp -ltorch_python -lc10 -mavx2 -mfma -D CPU_CAPABILITY_AVX2 -O3 -DNDEBUG -ffast-math -fno-finite-math-only -fno-unsafe-math-optimizations -ffp-contract=off -march=native -fopenmp -D TORCH_INDUCTOR_CPP_WRAPPER -D C10_USING_CUSTOM_GENERATED_MACROS -o /yatmp/scr1/torchinductor_laurent/m4/cm4llt7tjnefqnqsjkj2tlli2ips4twifait3zzvm3ed2u6gy5it.so
Output:
/yatmp/scr1/torchinductor_laurent/m4/cm4llt7tjnefqnqsjkj2tlli2ips4twifait3zzvm3ed2u6gy5it.cpp: In function ‘T parse_arg(PyObject*, size_t) [with T = long int; PyObject = _object; size_t = long unsigned int]’:
/yatmp/scr1/torchinductor_laurent/m4/cm4llt7tjnefqnqsjkj2tlli2ips4twifait3zzvm3ed2u6gy5it.cpp:106:10: error: expected identifier before ‘[’ token
[[unlikely]] throw std::runtime_error("expected int arg");
^
/yatmp/scr1/torchinductor_laurent/m4/cm4llt7tjnefqnqsjkj2tlli2ips4twifait3zzvm3ed2u6gy5it.cpp: In lambda function:
/yatmp/scr1/torchinductor_laurent/m4/cm4llt7tjnefqnqsjkj2tlli2ips4twifait3zzvm3ed2u6gy5it.cpp:106:22: error: expected ‘{’ before ‘throw’
[[unlikely]] throw std::runtime_error("expected int arg");
^~~~~
/yatmp/scr1/torchinductor_laurent/m4/cm4llt7tjnefqnqsjkj2tlli2ips4twifait3zzvm3ed2u6gy5it.cpp: In function ‘T parse_arg(PyObject*, size_t) [with T = long int; PyObject = _object; size_t = long unsigned int]’:
/yatmp/scr1/torchinductor_laurent/m4/cm4llt7tjnefqnqsjkj2tlli2ips4twifait3zzvm3ed2u6gy5it.cpp:106:22: error: expected ‘;’ before ‘throw’
/yatmp/scr1/torchinductor_laurent/m4/cm4llt7tjnefqnqsjkj2tlli2ips4twifait3zzvm3ed2u6gy5it.cpp: In function ‘PyObject* kernel_py(PyObject*, PyObject*)’:
/yatmp/scr1/torchinductor_laurent/m4/cm4llt7tjnefqnqsjkj2tlli2ips4twifait3zzvm3ed2u6gy5it.cpp:115:14: error: expected identifier before ‘[’ token
[[unlikely]] throw std::runtime_error("tuple args required");
^
/yatmp/scr1/torchinductor_laurent/m4/cm4llt7tjnefqnqsjkj2tlli2ips4twifait3zzvm3ed2u6gy5it.cpp: In lambda function:
/yatmp/scr1/torchinductor_laurent/m4/cm4llt7tjnefqnqsjkj2tlli2ips4twifait3zzvm3ed2u6gy5it.cpp:115:26: error: expected ‘{’ before ‘throw’
[[unlikely]] throw std::runtime_error("tuple args required");
^~~~~
/yatmp/scr1/torchinductor_laurent/m4/cm4llt7tjnefqnqsjkj2tlli2ips4twifait3zzvm3ed2u6gy5it.cpp: In function ‘PyObject* kernel_py(PyObject*, PyObject*)’:
/yatmp/scr1/torchinductor_laurent/m4/cm4llt7tjnefqnqsjkj2tlli2ips4twifait3zzvm3ed2u6gy5it.cpp:115:26: error: expected ‘;’ before ‘throw’
/yatmp/scr1/torchinductor_laurent/m4/cm4llt7tjnefqnqsjkj2tlli2ips4twifait3zzvm3ed2u6gy5it.cpp:117:14: error: expected identifier before ‘[’ token
[[unlikely]] throw std::runtime_error("requires 10 args");
^
/yatmp/scr1/torchinductor_laurent/m4/cm4llt7tjnefqnqsjkj2tlli2ips4twifait3zzvm3ed2u6gy5it.cpp: In lambda function:
/yatmp/scr1/torchinductor_laurent/m4/cm4llt7tjnefqnqsjkj2tlli2ips4twifait3zzvm3ed2u6gy5it.cpp:117:26: error: expected ‘{’ before ‘throw’
[[unlikely]] throw std::runtime_error("requires 10 args");
^~~~~
/yatmp/scr1/torchinductor_laurent/m4/cm4llt7tjnefqnqsjkj2tlli2ips4twifait3zzvm3ed2u6gy5it.cpp: In function ‘PyObject* kernel_py(PyObject*, PyObject*)’:
/yatmp/scr1/torchinductor_laurent/m4/cm4llt7tjnefqnqsjkj2tlli2ips4twifait3zzvm3ed2u6gy5it.cpp:117:26: error: expected ‘;’ before ‘throw’
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
I hope there is an obvious fix; otherwise, I will try to reduce my code to a minimal example and post it.
Thanks for any help.
Collecting environment information…
PyTorch version: 2.4.0
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A
OS: SUSE Linux Enterprise Server 15 SP5 (x86_64)
GCC version: (SUSE Linux) 7.5.0
Clang version: Could not collect
CMake version: version 3.20.4
Libc version: glibc-2.31
Python version: 3.11.4 | packaged by conda-forge | (main, Jun 10 2023, 18:08:17) [GCC 12.2.0] (64-bit runtime)
Python platform: Linux-5.14.21-150500.55.52-default-x86_64-with-glibc2.31
Is CUDA available: False
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: GPU 0: GRID T4-2B
Nvidia driver version: 470.239.06
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 45 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 6
On-line CPU(s) list: 0-5
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Gold 6142 CPU @ 2.60GHz
CPU family: 6
Model: 85
Thread(s) per core: 1
Core(s) per socket: 1
Socket(s): 6
Stepping: 4
BogoMIPS: 5187.81
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 invpcid avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat pku ospke md_clear flush_l1d arch_capabilities
Hypervisor vendor: VMware
Virtualization type: full
L1d cache: 192 KiB (6 instances)
L1i cache: 192 KiB (6 instances)
L2 cache: 6 MiB (6 instances)
L3 cache: 132 MiB (6 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-5
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; IBRS, IBPB conditional, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] pytorch-optimizer==2.11.0
[pip3] torch==2.4.0
[pip3] torch_geometric==2.4.0
[pip3] torchdiffeq==0.2.2
[pip3] torchmetrics==0.11.4
[conda] blas 1.0 mkl
[conda] mkl 2020.1 217
[conda] mkl-service 2.3.0
[conda] mkl_fft 1.1.0
[conda] mkl_random 1.1.1
[conda] numpy 1.18.5
[conda] numpy-base 1.18.5
[conda] numpy-stl 2.12.0
[conda] numpydoc 1.1.0
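One thing stands out in the environment information. The system compiler is GCC 7.5.0, and every error in the compiler output above is at a [[unlikely]] attribute in the Inductor-generated C++. [[likely]] and [[unlikely]] are C++20 attributes that GCC only accepts from version 9 onward. A plausible cause, though not a confirmed one, is that Inductor is calling a g++ that is too old for the code it generates. The Python sketch below is a quick way to check which g++ is on the PATH. The function names and the GCC 9 threshold are my own assumptions, not part of PyTorch.

```python
import re
import shutil
import subprocess
from typing import Optional, Tuple

# Assumption: GCC 9 is the first release that accepts the C++20
# [[likely]]/[[unlikely]] attributes seen in the failing generated code.
MIN_GCC_FOR_UNLIKELY = (9, 0, 0)


def parse_gcc_version(version_text: str) -> Optional[Tuple[int, ...]]:
    """Extract (major, minor, patch) from the first line of `g++ --version`."""
    stripped = version_text.strip()
    first_line = stripped.splitlines()[0] if stripped else ""
    match = re.search(r"(\d+)\.(\d+)\.(\d+)", first_line)
    if match is None:
        return None
    return tuple(int(part) for part in match.groups())


def installed_gxx_version(cxx: str = "g++") -> Optional[Tuple[int, ...]]:
    """Run `<cxx> --version` and parse it; return None if the compiler is missing."""
    path = shutil.which(cxx)
    if path is None:
        return None
    result = subprocess.run(
        [path, "--version"], capture_output=True, text=True, check=False
    )
    return parse_gcc_version(result.stdout)


def supports_unlikely_attribute(version: Optional[Tuple[int, ...]]) -> bool:
    """True if the parsed version is at or above the assumed GCC 9 threshold."""
    return version is not None and version >= MIN_GCC_FOR_UNLIKELY


if __name__ == "__main__":
    version = installed_gxx_version()
    print(f"g++ version on PATH: {version}")
    if not supports_unlikely_attribute(version):
        print("This g++ is probably too old for the C++ that Inductor generates.")
```

If it reports an old version, one thing to try is installing a newer g++ (for example the conda-forge gxx package) and exporting CXX=/path/to/that/g++ before starting Python. In the 2.4 sources, the default for torch._inductor.config.cpp.cxx reads the CXX environment variable, but that is worth confirming in your own install.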