MultivariateNormal on GPU segmentation fault

I'm trying to create a distribution on the GPU, but I get a segmentation fault.

Here is the code:

from torch.distributions.multivariate_normal import MultivariateNormal
import torch
mean = torch.ones(3).cuda()
scale = torch.ones(3).cuda()
mvn = MultivariateNormal(mean, torch.diag(scale))

I cannot reproduce this issue using the latest PyTorch release (1.7.1) on a Titan V, and print(mvn) yields:

MultivariateNormal(loc: torch.Size([3]), covariance_matrix: torch.Size([3, 3]))

Which PyTorch version, CUDA version, and GPU are you using?
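You can check them with something like:

import torch

# report the environment so the issue can be reproduced
print(torch.__version__)              # PyTorch version
print(torch.version.cuda)             # CUDA version the binary was built with
print(torch.cuda.get_device_name(0))  # name of the first visible GPU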

I ran into the same problem when running this code:

from torch.distributions.multivariate_normal import MultivariateNormal
import torch
mean = torch.ones(3).cuda()
scale = torch.ones(3).cuda()
mvn = MultivariateNormal(mean, torch.diag(scale))

I got a Segmentation fault (core dumped).
My PyTorch version is 1.7.1, my GPU is an RTX 3090, and the CUDA version is 11.0.

I cannot reproduce this issue on an RTX 3090 using the PyTorch 1.7.1 conda binaries with cudatoolkit=11.0:

# tmp.py
from torch.distributions.multivariate_normal import MultivariateNormal
import torch
print(torch.cuda.get_device_name(0))
print(torch.__version__)
print(torch.version.cuda)
mean = torch.ones(3).cuda()
scale = torch.ones(3).cuda()
mvn = MultivariateNormal(mean, torch.diag(scale))
print(mvn)

Output:

GeForce RTX 3090
1.7.1
11.0
MultivariateNormal(loc: torch.Size([3]), covariance_matrix: torch.Size([3, 3]))

Thanks for your reply!
I ran the tmp.py script and got the following output:

GeForce RTX 3090
1.7.1
11.0
Segmentation fault (core dumped)

If mean and scale are CPU tensors, everything works fine (see the sketch after the output below):

GeForce RTX 3090
1.7.1
11.0
MultivariateNormal(loc: torch.Size([3]), covariance_matrix: torch.Size([3, 3]))
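That is, something like this runs without crashing (tmp.py with the .cuda() calls removed):

from torch.distributions.multivariate_normal import MultivariateNormal
import torch
print(torch.cuda.get_device_name(0))
print(torch.__version__)
print(torch.version.cuda)
# mean and scale stay on the CPU this time
mean = torch.ones(3)
scale = torch.ones(3)
mvn = MultivariateNormal(mean, torch.diag(scale))
print(mvn)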

My Python environment was created with conda, and the Python version is 3.6.

Could you check the backtrace via:

gdb --args python tmp.py
...
run
...
bt

and post it here?

Hi @ptrblck, I got the following output:

Starting program: /home/user/anaconda3/envs/vln/bin/python test.py
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
[New Thread 0x7fff17bd3700 (LWP 30501)]
GeForce RTX 3090
1.7.1
11.0
[New Thread 0x7fff173d2700 (LWP 30502)]
[New Thread 0x7fff16bd1700 (LWP 30503)]

Thread 1 "python" received signal SIGSEGV, Segmentation fault.
0x00007fffcfc35507 in ?? () from /usr/local/cuda-11.0/lib64/libcublasLt.so.11
(gdb) bt
#0  0x00007fffcfc35507 in ?? () from /usr/local/cuda-11.0/lib64/libcublasLt.so.11
#1  0x00007fffcfc36a24 in free_gemm_select () from /usr/local/cuda-11.0/lib64/libcublasLt.so.11
#2  0x00007fffda9fdaf5 in cublasDestroy_v2 () from /home/user/anaconda3/envs/vln/lib/python3.6/site-packages/torch/lib/../../../../libcublas.so.11
#3  0x00007fff77b71380 in magma_queue_destroy_internal () from /home/user/anaconda3/envs/vln/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so
#4  0x00007fff77b91912 in magma_spotrf_LL_expert_gpu () from /home/user/anaconda3/envs/vln/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so
#5  0x00007fff77b919c4 in magma_spotrf_gpu () from /home/user/anaconda3/envs/vln/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so
#6  0x00007fff76460358 in void at::native::magmaCholesky<float>(magma_uplo_t, int, float*, int, int*) () from /home/user/anaconda3/envs/vln/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so
#7  0x00007fff7646dfcc in at::native::_cholesky_helper_cuda(at::Tensor const&, bool) () from /home/user/anaconda3/envs/vln/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so
#8  0x00007fff75a34a8f in at::CUDAType::_cholesky_helper(at::Tensor const&, bool) () from /home/user/anaconda3/envs/vln/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so
#9  0x00007fffc7df8af0 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<at::Tensor (*)(at::Tensor const&, bool), at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, bool> >, at::Tensor (at::Tensor const&, bool)>::call(c10::OperatorKernel*, at::Tensor const&, bool) () from /home/user/anaconda3/envs/vln/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so
#10 0x00007fffc868c6db in at::Tensor c10::Dispatcher::callWithDispatchKey<at::Tensor, at::Tensor const&, bool>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, bool)> const&, c10::DispatchKey, at::Tensor const&, bool) const () from /home/user/anaconda3/envs/vln/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so
#11 0x00007fffc8574dc7 in at::_cholesky_helper(at::Tensor const&, bool) () from /home/user/anaconda3/envs/vln/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so
#12 0x00007fffc7f1ccc3 in at::native::cholesky(at::Tensor const&, bool) () from /home/user/anaconda3/envs/vln/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so
#13 0x00007fffc874c2ff in at::TypeDefault::cholesky(at::Tensor const&, bool) () from /home/user/anaconda3/envs/vln/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so
#14 0x00007fffc7df8af0 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<at::Tensor (*)(at::Tensor const&, bool), at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, bool> >, at::Tensor (at::Tensor const&, bool)>::call(c10::OperatorKernel*, at::Tensor const&, bool) () from /home/user/anaconda3/envs/vln/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so
#15 0x00007fffc868c6db in at::Tensor c10::Dispatcher::callWithDispatchKey<at::Tensor, at::Tensor const&, bool>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, bool)> const&, c10::DispatchKey, at::Tensor const&, bool) const () from /home/user/anaconda3/envs/vln/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so
#16 0x00007fffc8575207 in at::cholesky(at::Tensor const&, bool) () from /home/user/anaconda3/envs/vln/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so
#17 0x00007fffc9b371e5 in torch::autograd::VariableType::(anonymous namespace)::cholesky(at::Tensor const&, bool) () from /home/user/anaconda3/envs/vln/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so
#18 0x00007fffc7df8af0 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<at::Tensor (*)(at::Tensor const&, bool), at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, bool> >, at::Tensor (at::Tensor const&, bool)>::call(c10::OperatorKernel*, at::Tensor const&, bool) () from /home/user/anaconda3/envs/vln/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so
#19 0x00007fffc868c6db in at::Tensor c10::Dispatcher::callWithDispatchKey<at::Tensor, at::Tensor const&, bool>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, bool)> const&, c10::DispatchKey, at::Tensor const&, bool) const () from /home/user/anaconda3/envs/vln/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so
#20 0x00007fffc87f8d07 in at::Tensor::cholesky(bool) const () from /home/user/anaconda3/envs/vln/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so
#21 0x00007fffcebd7905 in torch::autograd::THPVariable_cholesky () from /home/user/anaconda3/envs/vln/lib/python3.6/site-packages/torch/lib/libtorch_python.so
#22 0x0000555555666a14 in _PyCFunction_FastCallDict () at /tmp/build/80754af9/python_1599604603603/work/Objects/methodobject.c:231
#23 0x00005555556eea5c in call_function () at /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:4851
#24 0x000055555571125a in _PyEval_EvalFrameDefault () at /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:3335
#25 0x00005555556e8166 in _PyEval_EvalCodeWithName () at /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:4166
#26 0x00005555556e932c in _PyFunction_FastCallDict () at /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:5084
#27 0x0000555555666ddf in _PyObject_FastCallDict () at /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:2310
#28 0x000055555566b873 in _PyObject_Call_Prepend () at /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:2373
#29 0x000055555566681e in PyObject_Call () at /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:2261
#30 0x00005555556bf88b in slot_tp_init () at /tmp/build/80754af9/python_1599604603603/work/Objects/typeobject.c:6420
#31 0x00005555556eed97 in type_call () at /tmp/build/80754af9/python_1599604603603/work/Objects/typeobject.c:915
#32 0x0000555555666bfb in _PyObject_FastCallDict () at /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:2331
#33 0x00005555556eebae in call_function () at /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:4875
#34 0x000055555571125a in _PyEval_EvalFrameDefault () at /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:3335
#35 0x00005555556e9969 in _PyEval_EvalCodeWithName (qualname=0x0, name=<optimized out>, closure=0x0, kwdefs=0x0, defcount=0, defs=0x0, kwstep=2, kwcount=<optimized out>, kwargs=0x0, kwnames=0x0, argcount=0, args=0x0,  locals=0x7ffff7f3f0d8, globals=0x7ffff7f3f0d8, _co=0x7ffff7efaa50) at /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:4166
#36 PyEval_EvalCodeEx () at /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:4187
#37 0x00005555556ea70c in PyEval_EvalCode (co=co@entry=0x7ffff7efaa50, globals=globals@entry=0x7ffff7f3f0d8, locals=locals@entry=0x7ffff7f3f0d8) at /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:731
#38 0x000055555576a574 in run_mod () at /tmp/build/80754af9/python_1599604603603/work/Python/pythonrun.c:1025
#39 0x000055555576a971 in PyRun_FileExFlags () at /tmp/build/80754af9/python_1599604603603/work/Python/pythonrun.c:978
#40 0x000055555576ab73 in PyRun_SimpleFileExFlags () at /tmp/build/80754af9/python_1599604603603/work/Python/pythonrun.c:419
#41 0x000055555576ac7d in PyRun_AnyFileExFlags () at /tmp/build/80754af9/python_1599604603603/work/Python/pythonrun.c:81
#42 0x000055555576e663 in run_file (p_cf=0x7fffffffdecc, filename=0x5555558aa780 L"test.py", fp=0x5555558e7940) at /tmp/build/80754af9/python_1599604603603/work/Modules/main.c:340
#43 Py_Main () at /tmp/build/80754af9/python_1599604603603/work/Modules/main.c:811
#44 0x000055555563843e in main () at /tmp/build/80754af9/python_1599604603603/work/Programs/python.c:69
#45 0x00007ffff77e4bf7 in __libc_start_main (main=0x555555638350 <main>, argc=2, argv=0x7fffffffe0d8, init=<optimized out>, fini=<optimized out>, rtld_fini=<optimized out>, stack_end=0x7fffffffe0c8) at ../csu/libc-start.c:310
#46 0x0000555555717d0b in _start () at ../sysdeps/x86_64/elf/start.S:103
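
Judging from the backtrace, the crash happens inside the GPU Cholesky factorization (at::cholesky going down into magma_spotrf_gpu), which MultivariateNormal uses to factor the covariance matrix. A minimal sketch along these lines might trigger it on its own (assuming the same environment):

import torch

# MultivariateNormal Cholesky-factorizes the covariance matrix internally;
# the backtrace points at that GPU kernel, so this call alone may segfault
# on the affected setup.
cov = torch.diag(torch.ones(3)).cuda()
L = torch.cholesky(cov)
print(L)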

Thank you for the stack trace. That’s really helpful!
Could you install the nightly binaries as described on the website and see if you are still running into this issue?
This PR should have fixed the out-of-bounds reads in cholesky, but the fix is not available in 1.7.1.
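
After installing a nightly, re-running the repro should confirm whether the fix helps; a quick sketch:

import torch
from torch.distributions.multivariate_normal import MultivariateNormal

# a nightly build reports a ".dev" version string, e.g. 1.8.0.dev20210120
# (the exact string will differ)
print(torch.__version__)

mean = torch.ones(3).cuda()
scale = torch.ones(3).cuda()
mvn = MultivariateNormal(mean, torch.diag(scale))
print(mvn)  # should print the distribution instead of segfaulting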