Core dumped when training, no error messages

I found an older thread of mine, Random core dumps and segmentation fault - #2 by ptrblck, from when I was unable to reproduce the problem, but since that was a long time ago and this seems like a totally different issue, I am making a new thread.

I am using PyTorch’s higher (GitHub - facebookresearch/higher: higher is a pytorch library allowing users to obtain higher order gradients over losses spanning training loops rather than individual training steps.) for some meta-learning, and I get a core dump every time I run the code. The loss does sometimes go to NaN, so there is another problem I should fix, but I believe this core dump still should not happen.
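The training loop is roughly of this shape (a simplified sketch of the pattern, not my exact code):

import torch
import torch.nn as nn
import higher

# Simplified sketch of the pattern, not my exact code.
model = nn.Linear(10, 10).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(32, 10, device="cuda")

with higher.innerloop_ctx(model, opt) as (fmodel, diffopt):
    for _ in range(5):                        # inner loop
        inner_loss = fmodel(x).pow(2).mean()
        diffopt.step(inner_loss)              # differentiable parameter update
    outer_loss = fmodel(x).pow(2).mean()
    outer_loss.backward()                     # the crash happens during a backward like this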

torch: 1.8.1+cu111
higher: 0.2.1 (I believe)

Any idea what could be causing this?

Thread 32 "python" received signal SIGSEGV, Segmentation fault.
[Switching to Thread 0x7ffe40a80700 (LWP 1557747)]
0x00007fff94b7d48f in at::cuda::getCurrentCUDABlasHandle() () from /home/jeff/.venv/env/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so
(gdb) bt
#0  0x00007fff94b7d48f in at::cuda::getCurrentCUDABlasHandle() () from /home/jeff/.venv/env/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so
#1  0x00007fff94b6b7dd in void at::cuda::blas::gemm<float>(char, char, long, long, long, float, float const*, long, float const*, long, float, float*, long) ()
   from /home/jeff/.venv/env/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so
#2  0x00007fff95d878f0 in at::native::(anonymous namespace)::addmm_out_cuda_impl(at::Tensor&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar, c10::Scalar) ()
   from /home/jeff/.venv/env/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so
#3  0x00007fff95d89092 in at::native::mm_cuda(at::Tensor const&, at::Tensor const&) () from /home/jeff/.venv/env/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so
#4  0x00007fff94bc2848 in at::(anonymous namespace)::(anonymous namespace)::wrapper_mm(at::Tensor const&, at::Tensor const&) ()
   from /home/jeff/.venv/env/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so
#5  0x00007fff94bc288f in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_mm>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&> >, at::Tensor (at::Tensor const&, at::Tensor const&)>::call(c10::OperatorKernel*, at::Tensor const&, at::Tensor const&) () from /home/jeff/.venv/env/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so
#6  0x00007fff8303e596 in at::Tensor c10::Dispatcher::call<at::Tensor, at::Tensor const&, at::Tensor const&>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&)> const&, at::Tensor const&, at::Tensor const&) const [clone .constprop.4808] () from /home/jeff/.venv/env/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#7  0x00007fff83043bbf in at::mm(at::Tensor const&, at::Tensor const&) () from /home/jeff/.venv/env/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#8  0x00007fff8492791c in torch::autograd::VariableType::(anonymous namespace)::mm(at::Tensor const&, at::Tensor const&) () from /home/jeff/.venv/env/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#9  0x00007fff84927e8f in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&), &torch::autograd::VariableType::(anonymous namespace)::mm>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&> >, at::Tensor (at::Tensor const&, at::Tensor const&)>::call(c10::OperatorKernel*, at::Tensor const&, at::Tensor const&) () from /home/jeff/.venv/env/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#10 0x00007fff834907d6 in at::Tensor c10::Dispatcher::call<at::Tensor, at::Tensor const&, at::Tensor const&>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&)> const&, at::Tensor const&, at::Tensor const&) const [clone .constprop.1487] () from /home/jeff/.venv/env/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#11 0x00007fff834949ff in at::Tensor::mm(at::Tensor const&) const () from /home/jeff/.venv/env/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#12 0x00007fff85366911 in torch::autograd::generated::details::mm_mat1_backward(at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::Scalar const&) ()
   from /home/jeff/.venv/env/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#13 0x00007fff8473c093 in torch::autograd::generated::AddmmBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) ()
   from /home/jeff/.venv/env/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#14 0x00007fff84df9771 in torch::autograd::Node::operator()(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) () from /home/jeff/.venv/env/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#15 0x00007fff84df557b in torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) () from /home/jeff/.venv/env/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#16 0x00007fff84df619f in torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) () from /home/jeff/.venv/env/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#17 0x00007fff84ded979 in torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) ()
   from /home/jeff/.venv/env/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#18 0x00007ffff5571293 in torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) ()
   from /home/jeff/.venv/env/lib/python3.8/site-packages/torch/lib/libtorch_python.so
#19 0x00007ffff65c8d84 in ?? () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#20 0x00007ffff7d9f609 in start_thread (arg=<optimized out>) at pthread_create.c:477
#21 0x00007ffff7edb293 in clone () at ../sysdeps/unix/sysv/linux/x86_64/clone.S:95

Do you have a way to reproduce the issue and the core dump, which you could share?

I was able to create a self-contained example in this gist:
Reproduction of an issue causing a core dump on Pytorch (1.8.1+cu111) and higher (0.2.1) · GitHub

I was able to update the gist and make it even shorter, with no dependence on the higher library.
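The repro boils down to something of this shape (a hypothetical reconstruction matching the backtraces in this thread, not the gist verbatim): plain forward/backward through a CUDA linear layer in a loop.

import torch
import torch.nn as nn

# Hypothetical reconstruction, not the gist verbatim: the segfault fires
# in the autograd worker thread when the backward mm kernel requests the
# cuBLAS handle (at::cuda::getCurrentCUDABlasHandle in the traces).
device = torch.device("cuda")
model = nn.Linear(64, 64).to(device)

for _ in range(1000):
    x = torch.randn(128, 64, device=device)
    model(x).sum().backward()
    model.zero_grad()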

I have a very similar issue. My code also segfaults in at::cuda::getCurrentCUDABlasHandle(), and Jeff’s shorter snippet results in:

(gdb) bt
#0  0x00007fffe8d4e6b3 in cudart::contextStateManager::getRuntimeContextState(cudart::contextState**, bool) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cuda.so
#1  0x00007fffe8d2abaa in cudart::cudaApiLaunchKernel(void const*, dim3, dim3, void**, unsigned long, CUstream_st*) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cuda.so
#2  0x00007fffe8d709b6 in cudaLaunchKernel () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cuda.so
#3  0x00007fffe80f2e4c in void gemv2T_kernel_val<int, int, float, float, float, float, 128, 16, 2, 2, false, false, cublasGemvParams<cublasGemvTensorStridedBatched<float const>, cublasGemvTensorStridedBatched<float const>, cublasGemvTensorStridedBatched<float>, float> >(cublasGemvParams<cublasGemvTensorStridedBatched<float const>, cublasGemvTensorStridedBatched<float const>, cublasGemvTensorStridedBatched<float>, float>, float, float) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cuda.so
#4  0x00007fffe8385b35 in gemm_matmulAlgo_gemv_strided<float, float, float>::run(gemmInternalParams_t const&, matmulAlgoConfig_t const&, void*) const () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cuda.so
#5  0x00007fffe80c4035 in run_matmul_template(matmul_template_factory_key const&, gemmInternalParams_t&, cublasLtMatmulAlgoInternal_t const*, void*, unsigned long) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cuda.so
#6  0x00007fffe80c1c4f in cublasLtMatmul () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cuda.so
#7  0x00007fff9728a308 in cublasGemvExLt(cublasContext*, CUstream_st*, cublasOperation_t, int, int, void const*, void const*, cudaDataType_t, int, void const*, cudaDataType_t, int, void const*, void*, cudaDataType_t, int, cublasComputeType_t, bool, int, bool, long long, long long, long long) [clone .part.16] () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cuda_cu.so
#8  0x00007fff97124687 in cublasSgemmRecursiveEntry(cublasContext*, int, int, int, int, int, float const*, float const*, int, float const*, int, float const*, float*, int) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cuda_cu.so
#9  0x00007fff97125377 in cublasSgemm_v2 () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cuda_cu.so
#10 0x00007fff94ffd8d9 in void at::cuda::blas::gemm<float>(char, char, long, long, long, float, float const*, long, float const*, long, float, float*, long) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cuda_cu.so
#11 0x00007fff962198f0 in at::native::(anonymous namespace)::addmm_out_cuda_impl(at::Tensor&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar, c10::Scalar) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cuda_cu.so
#12 0x00007fff9621b837 in at::native::addmm_out_cuda(at::Tensor&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar, c10::Scalar) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cuda_cu.so
#13 0x00007fff9621b903 in at::native::addmm_cuda(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar, c10::Scalar) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cuda_cu.so
#14 0x00007fff9505aea7 in at::(anonymous namespace)::(anonymous namespace)::wrapper_addmm(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar, c10::Scalar) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cuda_cu.so
#15 0x00007fff9505af3e in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar, c10::Scalar), &at::(anonymous namespace)::(anonymous namespace)::wrapper_addmm>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar, c10::Scalar> >, at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar, c10::Scalar)>::call(c10::OperatorKernel*, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar, c10::Scalar) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cuda_cu.so
#16 0x00007fff8369027c in at::Tensor c10::Dispatcher::call<at::Tensor, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar, c10::Scalar>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar, c10::Scalar)> const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar, c10::Scalar) const () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so
#17 0x00007fff834f7e21 in at::addmm(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar, c10::Scalar) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so
#18 0x00007fff84c7924b in torch::autograd::VariableType::(anonymous namespace)::addmm(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar, c10::Scalar) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so
#19 0x00007fff84c7988e in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar, c10::Scalar), &torch::autograd::VariableType::(anonymous namespace)::addmm>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar, c10::Scalar> >, at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar, c10::Scalar)>::call(c10::OperatorKernel*, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar, c10::Scalar) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so
#20 0x00007fff8369027c in at::Tensor c10::Dispatcher::call<at::Tensor, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar, c10::Scalar>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar, c10::Scalar)> const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar, c10::Scalar) const () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so
#21 0x00007fff834f7e21 in at::addmm(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar, c10::Scalar) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so
#22 0x00007fff83033dca in at::native::linear(at::Tensor const&, at::Tensor const&, at::Tensor const&) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so
#23 0x00007fff83834210 in at::(anonymous namespace)::(anonymous namespace)::wrapper_linear(at::Tensor const&, at::Tensor const&, at::Tensor const&) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so
#24 0x00007fff83856dad in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&), &c10::impl::detail::with_explicit_optional_tensors_<at::Tensor (at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&), at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&), c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_linear> >::wrapper>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&> >, at::Tensor (at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&)>::call(c10::OperatorKernel*, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so
#25 0x00007fff836277a9 in at::linear(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so
#26 0x00007ffff58b0aee in torch::autograd::THPVariable_linear(_object*, _object*, _object*) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_python.so

(gdb) info threads
  Id   Target Id                                            Frame
* 1    Thread 0x7ffff7bec740 (LWP 2919023) "python3"        0x00007fffe8d4e6b3 in cudart::contextStateManager::getRuntimeContextState(cudart::contextState**, bool) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cuda.so
  2    Thread 0x7fff09b97700 (LWP 2919034) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x7fff0c2d12e0 <thread_status+96>) at ../sysdeps/nptl/futex-internal.h:183
  3    Thread 0x7fff09396700 (LWP 2919035) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x7fff0c2d1360 <thread_status+224>) at ../sysdeps/nptl/futex-internal.h:183
  4    Thread 0x7fff04b95700 (LWP 2919036) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x7fff0c2d13e0 <thread_status+352>) at ../sysdeps/nptl/futex-internal.h:183
  5    Thread 0x7fff02394700 (LWP 2919037) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x7fff0c2d1460 <thread_status+480>) at ../sysdeps/nptl/futex-internal.h:183
  6    Thread 0x7ffeffb93700 (LWP 2919038) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x7fff0c2d14e0 <thread_status+608>) at ../sysdeps/nptl/futex-internal.h:183
  7    Thread 0x7ffefd392700 (LWP 2919039) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x7fff0c2d1560 <thread_status+736>) at ../sysdeps/nptl/futex-internal.h:183
  8    Thread 0x7ffefab91700 (LWP 2919040) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x7fff0c2d15e0 <thread_status+864>) at ../sysdeps/nptl/futex-internal.h:183
  9    Thread 0x7ffef8390700 (LWP 2919041) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x7fff0c2d1660 <thread_status+992>) at ../sysdeps/nptl/futex-internal.h:183
  10   Thread 0x7ffef5b8f700 (LWP 2919042) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x7fff0c2d16e0 <thread_status+1120>) at ../sysdeps/nptl/futex-internal.h:183
  11   Thread 0x7ffef338e700 (LWP 2919043) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x7fff0c2d1760 <thread_status+1248>) at ../sysdeps/nptl/futex-internal.h:183
  12   Thread 0x7ffef0b8d700 (LWP 2919044) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x7fff0c2d17e0 <thread_status+1376>) at ../sysdeps/nptl/futex-internal.h:183
  13   Thread 0x7ffeee38c700 (LWP 2919045) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x7fff0c2d1860 <thread_status+1504>) at ../sysdeps/nptl/futex-internal.h:183
  14   Thread 0x7ffeebb8b700 (LWP 2919046) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x7fff0c2d18e0 <thread_status+1632>) at ../sysdeps/nptl/futex-internal.h:183
  15   Thread 0x7ffee938a700 (LWP 2919047) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x7fff0c2d1960 <thread_status+1760>) at ../sysdeps/nptl/futex-internal.h:183
  16   Thread 0x7ffee6b89700 (LWP 2919048) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x7fff0c2d19e0 <thread_status+1888>) at ../sysdeps/nptl/futex-internal.h:183
  17   Thread 0x7ffee4388700 (LWP 2919049) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x7fff0c2d1a60 <thread_status+2016>) at ../sysdeps/nptl/futex-internal.h:183
  18   Thread 0x7ffee1b87700 (LWP 2919050) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x7fff0c2d1ae0 <thread_status+2144>) at ../sysdeps/nptl/futex-internal.h:183
  19   Thread 0x7ffedf386700 (LWP 2919051) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x7fff0c2d1b60 <thread_status+2272>) at ../sysdeps/nptl/futex-internal.h:183
  20   Thread 0x7ffedcb85700 (LWP 2919052) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x7fff0c2d1be0 <thread_status+2400>) at ../sysdeps/nptl/futex-internal.h:183
  21   Thread 0x7ffeda384700 (LWP 2919053) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x7fff0c2d1c60 <thread_status+2528>) at ../sysdeps/nptl/futex-internal.h:183
  22   Thread 0x7ffed7b83700 (LWP 2919054) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x7fff0c2d1ce0 <thread_status+2656>) at ../sysdeps/nptl/futex-internal.h:183
  23   Thread 0x7ffed5382700 (LWP 2919055) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x7fff0c2d1d60 <thread_status+2784>) at ../sysdeps/nptl/futex-internal.h:183
  24   Thread 0x7ffed2b81700 (LWP 2919056) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x7fff0c2d1de0 <thread_status+2912>) at ../sysdeps/nptl/futex-internal.h:183
  25   Thread 0x7ffecbcb8700 (LWP 2919057) "python3"        0x00007ffff7edbc90 in accept4 (fd=17, addr=..., addr_len=0x7ffecbcb7b74, flags=524288) at ../sysdeps/unix/sysv/linux/accept4.c:32
  26   Thread 0x7ffecb4b7700 (LWP 2919058) "cuda-EvtHandlr" 0x00007ffff7ecdaff in __GI___poll (fds=0x7ffea8000bc0, nfds=8, timeout=100) at ../sysdeps/unix/sysv/linux/poll.c:29
  27   Thread 0x7ffecacb6700 (LWP 2919059) "python3"        futex_abstimed_wait_cancelable (private=<optimized out>, abstime=0x7ffecacb5e80, clockid=<optimized out>, expected=0, futex_word=0x341d910) at ../sysdeps/nptl/futex-internal.h:320
  28   Thread 0x7ffec9499700 (LWP 2919062) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x676315fc) at ../sysdeps/nptl/futex-internal.h:183
  29   Thread 0x7ffec8c98700 (LWP 2919063) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x67631678) at ../sysdeps/nptl/futex-internal.h:183
  30   Thread 0x7ffec23bf700 (LWP 2919064) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x676316f8) at ../sysdeps/nptl/futex-internal.h:183
  31   Thread 0x7ffec1bbe700 (LWP 2919065) "python3"        futex_wait_cancelable (private=<optimized out>, expected=0, futex_word=0x67631778) at ../sysdeps/nptl/futex-internal.h:183

nvidia-drivers: 460.73.01
python: 3.8.5
torch: 1.8.1+cu111

What’s more interesting is that this problem occurs only on my old Maxwell GPU; on Turing and Ampere cards there’s no segfault (even with the same PyTorch version and the same NVIDIA drivers).

@deltaskelta Which GPU are you using? I’m sure I was able to reproduce this issue, but cannot run into it at the moment.

I’m using a GeForce GTX 1080 Ti.

Update:

  • Maxwell, Pascal GPUs - segfault
  • Turing, Ampere GPUs - OK

Also:

  • 1.8.0+cu111 - segfault
  • 1.8.1+cu111 - segfault
  • 1.8.1+cu102 - OK
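
If you want to check whether a given card falls in the affected range, the compute capability is enough (Maxwell = sm_5x, Pascal = sm_6x segfaulted for me; Turing = sm_75 and Ampere = sm_8x did not):

import torch

# Compute capability: Maxwell = 5.x, Pascal = 6.x (segfault for me);
# Turing = 7.5, Ampere = 8.x (no segfault in my tests).
major, minor = torch.cuda.get_device_capability()
print(f"sm_{major}{minor}")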

Hi, I’m also facing this same error on my GTX 1050. I’m using PyTorch 1.8.1+cu111 with driver version 460.39 and Python 3.8.5.
I ran the GitHub gist program above, and here is the gdb output:

Thread 12 "python3" received signal SIGSEGV, Segmentation fault.
[Switching to Thread 0x7ffecbfff700 (LWP 73486)]
0x00007fff9466648f in at::cuda::getCurrentCUDABlasHandle() () from /home/joybanerjee/.local/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so
(gdb) where
#0  0x00007fff9466648f in at::cuda::getCurrentCUDABlasHandle() () from /home/joybanerjee/.local/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so
#1  0x00007fff946547dd in void at::cuda::blas::gemm<float>(char, char, long, long, long, float, float const*, long, float const*, long, float, float*, long)
    () from /home/joybanerjee/.local/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so
#2  0x00007fff958708f0 in at::native::(anonymous namespace)::addmm_out_cuda_impl(at::Tensor&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar, c10::Scalar) () from /home/joybanerjee/.local/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so
#3  0x00007fff95872092 in at::native::mm_cuda(at::Tensor const&, at::Tensor const&) ()
   from /home/joybanerjee/.local/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so
#4  0x00007fff946ab848 in at::(anonymous namespace)::(anonymous namespace)::wrapper_mm(at::Tensor const&, at::Tensor const&) ()
   from /home/joybanerjee/.local/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so
#5  0x00007fff946ab88f in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_mm>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&> >, at::Tensor (at::Tensor const&, at::Tensor const&)>::call(c10::OperatorKernel*, at::Tensor const&, at::Tensor const&) ()
   from /home/joybanerjee/.local/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so
#6  0x00007fff82b27596 in at::Tensor c10::Dispatcher::call<at::Tensor, at::Tensor const&, at::Tensor const&>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&)> const&, at::Tensor const&, at::Tensor const&) const [clone .constprop.4808] ()
   from /home/joybanerjee/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#7  0x00007fff82b2cbbf in at::mm(at::Tensor const&, at::Tensor const&) () from /home/joybanerjee/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#8  0x00007fff8441091c in torch::autograd::VariableType::(anonymous namespace)::mm(at::Tensor const&, at::Tensor const&) ()
   from /home/joybanerjee/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#9  0x00007fff84410e8f in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&), &torch::autograd::VariableType::(anonymous namespace)::mm>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&> >, at::Tensor (at::Tensor const&, at::Tensor const&)>::call(c10::OperatorKernel*, at::Tensor const&, at::Tensor const&) ()
   from /home/joybanerjee/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#10 0x00007fff82f797d6 in at::Tensor c10::Dispatcher::call<at::Tensor, at::Tensor const&, at::Tensor const&>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&)> const&, at::Tensor const&, at::Tensor const&) const [clone .constprop.1487] ()
   from /home/joybanerjee/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#11 0x00007fff82f7d9ff in at::Tensor::mm(at::Tensor const&) const () from /home/joybanerjee/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#12 0x00007fff84e4f911 in torch::autograd::generated::details::mm_mat1_backward(at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::Scalar const&) () from /home/joybanerjee/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#13 0x00007fff84225093 in torch::autograd::generated::AddmmBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) ()
   from /home/joybanerjee/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#14 0x00007fff848e2771 in torch::autograd::Node::operator()(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) ()
   from /home/joybanerjee/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#15 0x00007fff848de57b in torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) () from /home/joybanerjee/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#16 0x00007fff848df19f in torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) ()
   from /home/joybanerjee/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#17 0x00007fff848d6979 in torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) ()
   from /home/joybanerjee/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#18 0x00007ffff505a293 in torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) () from /home/joybanerjee/.local/lib/python3.8/site-packages/torch/lib/libtorch_python.so
#19 0x00007ffff6071d84 in ?? () from /lib/x86_64-linux-gnu/libstdc++.so.6
#20 0x00007ffff7d98609 in start_thread (arg=<optimized out>) at pthread_create.c:477
#21 0x00007ffff7ed4293 in clone () at ../sysdeps/unix/sysv/linux/x86_64/clone.S:95

Is there any solution for this?

A workaround could be to use the conda binaries, as it seems that the pip wheels have issues with sm_61. We are currently looking into it, but are unsure about the root cause.
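For reference, installing the conda binary would look something like this (the exact channels and pins for the CUDA 11.1 build may differ; the install selector on pytorch.org has the authoritative command):

conda install pytorch=1.8.1 cudatoolkit=11.1 -c pytorch -c conda-forge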

I’m having the same issue and can confirm part of the analysis by @Pawel_Wiejacha, as the segfault happens on a Quadro P520 (Pascal architecture) with PyTorch 1.8.0 and CUDA 11.1. PyTorch is running in a virtualenv with Python 3.8.10. The gdb trace is appended below.

Some of the libraries I use aren’t available for conda. I tried creating the corresponding recipes, but ran into some upstream errors with the required dependencies. Therefore, switching from wheels (available for all required libraries) to conda packages (not available for six of them) would come with considerable effort.

@ptrblck Any news on the issue? Or some other workaround maybe? (and btw thanks for your great work in the forums answering a lot of more or less coherent questions)

#0  0x00007fff951fb48f in at::cuda::getCurrentCUDABlasHandle() () from /home/mbeeking/.local/share/virtualenvs/traveltime-prediction-UXvvem08/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so
#1  0x00007fff951e97dd in void at::cuda::blas::gemm<float>(char, char, long, long, long, float, float const*, long, float const*, long, float, float*, long) ()
   from /home/mbeeking/.local/share/virtualenvs/traveltime-prediction-UXvvem08/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so
#2  0x00007fff964058f0 in at::native::(anonymous namespace)::addmm_out_cuda_impl(at::Tensor&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar, c10::Scalar) ()
   from /home/mbeeking/.local/share/virtualenvs/traveltime-prediction-UXvvem08/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so
#3  0x00007fff96407092 in at::native::mm_cuda(at::Tensor const&, at::Tensor const&) ()
   from /home/mbeeking/.local/share/virtualenvs/traveltime-prediction-UXvvem08/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so
#4  0x00007fff95240848 in at::(anonymous namespace)::(anonymous namespace)::wrapper_mm(at::Tensor const&, at::Tensor const&) ()
   from /home/mbeeking/.local/share/virtualenvs/traveltime-prediction-UXvvem08/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so
#5  0x00007fff9524088f in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_mm>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&> >, at::Tensor (at::Tensor const&, at::Tensor const&)>::call(c10::OperatorKernel*, at::Tensor const&, at::Tensor const&) () from /home/mbeeking/.local/share/virtualenvs/traveltime-prediction-UXvvem08/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so
#6  0x00007fff836be1b6 in at::Tensor c10::Dispatcher::call<at::Tensor, at::Tensor const&, at::Tensor const&>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&)> const&, at::Tensor const&, at::Tensor const&) const [clone .constprop.4808] () from /home/mbeeking/.local/share/virtualenvs/traveltime-prediction-UXvvem08/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#7  0x00007fff836c37df in at::mm(at::Tensor const&, at::Tensor const&) () from /home/mbeeking/.local/share/virtualenvs/traveltime-prediction-UXvvem08/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#8  0x00007fff84fa750c in torch::autograd::VariableType::(anonymous namespace)::mm(at::Tensor const&, at::Tensor const&) ()
   from /home/mbeeking/.local/share/virtualenvs/traveltime-prediction-UXvvem08/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#9  0x00007fff84fa7a7f in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&), &torch::autograd::VariableType::(anonymous namespace)::mm>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&> >, at::Tensor (at::Tensor const&, at::Tensor const&)>::call(c10::OperatorKernel*, at::Tensor const&, at::Tensor const&) () from /home/mbeeking/.local/share/virtualenvs/traveltime-prediction-UXvvem08/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#10 0x00007fff83b103c6 in at::Tensor c10::Dispatcher::call<at::Tensor, at::Tensor const&, at::Tensor const&>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&)> const&, at::Tensor const&, at::Tensor const&) const [clone .constprop.1487] () from /home/mbeeking/.local/share/virtualenvs/traveltime-prediction-UXvvem08/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#11 0x00007fff83b145ef in at::Tensor::mm(at::Tensor const&) const () from /home/mbeeking/.local/share/virtualenvs/traveltime-prediction-UXvvem08/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#12 0x00007fff859e7421 in torch::autograd::generated::details::mm_mat2_backward(at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::Scalar const&) ()
   from /home/mbeeking/.local/share/virtualenvs/traveltime-prediction-UXvvem08/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#13 0x00007fff84dd37ae in torch::autograd::generated::MmBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) ()
   from /home/mbeeking/.local/share/virtualenvs/traveltime-prediction-UXvvem08/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#14 0x00007fff85479361 in torch::autograd::Node::operator()(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) ()
   from /home/mbeeking/.local/share/virtualenvs/traveltime-prediction-UXvvem08/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#15 0x00007fff8547516b in torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) () from /home/mbeeking/.local/share/virtualenvs/traveltime-prediction-UXvvem08/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#16 0x00007fff85475d8f in torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) ()
   from /home/mbeeking/.local/share/virtualenvs/traveltime-prediction-UXvvem08/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#17 0x00007fff8546d569 in torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) ()
   from /home/mbeeking/.local/share/virtualenvs/traveltime-prediction-UXvvem08/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#18 0x00007ffff5bef213 in torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) ()
   from /home/mbeeking/.local/share/virtualenvs/traveltime-prediction-UXvvem08/lib/python3.8/site-packages/torch/lib/libtorch_python.so
#19 0x00007ffff6c043c4 in std::execute_native_thread_routine (__p=0x5555c5f5a330) at /build/gcc/src/gcc/libstdc++-v3/src/c++11/thread.cc:82
#20 0x00007ffff7f4a259 in start_thread () from /usr/lib/libpthread.so.0
#21 0x00007ffff7d235e3 in clone () from /usr/lib/libc.so.6

PyTorch 1.9.0 fixed the issue in the pip wheels (it wasn’t observed in the conda binaries), so you would have to update to one of those.
If you depend on PyTorch 1.8.0, you could build it from source, since the issue was caused by splitting the library (necessary for statically linking CUDA, cudnn, etc.).
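For anyone else hitting this, upgrading the pip wheel would look something like this (cu111 build shown; pick the build matching your setup):

pip install torch==1.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html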

Although switching to torch 1.9.0 introduced a bunch of different problems at first, I was able to sort them out. Thank you for reassuring me the update would be worth the trouble.
