Issue with torch.sparse.mm while running on GPU

I am trying to multiply a sparse matrix with a dense matrix using torch.sparse.mm. It works fine on the CPU but fails at random iterations on the GPU. Below is a sample script that reproduces the error.

import torch
torch.manual_seed(0)
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

device = torch.device(0 if torch.cuda.is_available() else "cpu")
# device = 'cpu'
noof_runs = 1000*500

for i in range(noof_runs):
    
    # random sizes: the sparse matrix is (m, 1), the dense matrix is (1, k)
    dense_size = torch.Size([1, torch.randint(1, 2000, (1,)).item()])
    sparse_size = (torch.randint(1, 2000, (1,)).item(), 1)
    if i%1000==0:
        print(f' i :{i%1000} \n')
        print(i, sparse_size, dense_size)

    # sparse matrix: a single nonzero (value 1.0) at position (0, 0)
    value_tensor = torch.ones(1, dtype=torch.float64, device=device)
    index_tensor = torch.tensor([[0], [0]], device=device)  # (2, nnz) COO indices
    sparse = torch.sparse_coo_tensor(index_tensor, value_tensor, sparse_size)

    # dense matrix
    dense = torch.rand(dense_size, dtype=torch.float64).to(device)
    
    try:
        mat = torch.sparse.mm(sparse, dense)
    except:
        print(i, sparse_size, dense_size)
        raise
    

Error output:

usr1@vapnik-SYS:~/home$ python test_sparse_mm.py 
 i :0 

0 (1, 1) torch.Size([1, 1202])
 i :0 

1000 (1815, 1) torch.Size([1, 1694])
1205 (956, 1) torch.Size([1, 1612])
Traceback (most recent call last):
  File "/home/test_sparse_mm.py", line 27, in <module>
    mat = torch.sparse.mm(sparse, dense)
RuntimeError: CUDA error: an illegal memory access was encountered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
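
For reference, the failing call boils down to a single sparse-dense product. Below is a single-iteration sketch using the shapes from the last line printed above; since the crash is intermittent, this call in isolation is not guaranteed to fail.

import torch

device = torch.device(0 if torch.cuda.is_available() else "cpu")

# shapes taken from the failing iteration reported above
sparse_size = (956, 1)
dense_size = (1, 1612)

value_tensor = torch.ones(1, dtype=torch.float64, device=device)
index_tensor = torch.tensor([[0], [0]], device=device)  # single nonzero at (0, 0)
sparse = torch.sparse_coo_tensor(index_tensor, value_tensor, sparse_size)
dense = torch.rand(dense_size, dtype=torch.float64, device=device)

mat = torch.sparse.mm(sparse, dense)  # (956, 1) @ (1, 1612) -> (956, 1612)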

I am using PyTorch 2.0.1 (py3.10_cuda11.8_cudnn8.7.0_0) on an NVIDIA RTX 4090 GPU. I was able to reproduce the same error on a P100-based GPU with PyTorch 2.0.1 as well. I would appreciate it if someone could look into this issue and help resolve it.

Could you run the workload with compute-sanitizer and post the logs here, please?
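
For example, something like:

compute-sanitizer python test_sparse_mm.py

(compute-sanitizer's default memcheck tool flags invalid device memory accesses and prints a host backtrace.)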

Please find the full log text (a 45 MB file, too large to post inline) at the link below.

https://drive.google.com/file/d/1DRdwq3QjFbD6MyDZj0YHNtiuj2LFjqII/view?usp=sharing

Partial log text is given below.

========= COMPUTE-SANITIZER
========= Invalid __global__ read of size 8 bytes
=========     at 0x1470 in void cusparse::load_balancing_kernel<(unsigned int)256, (unsigned int)1, (unsigned long)0, int, int, cusparse::CsrMMOpAlg1<cusparse::CsrMMPolicyAlg1<int, double, double, double>, (bool)0, (bool)0, (bool)1, double, int>, int, double, double, double>(const T5 *, T4, T5, T5, int, const T4 *, T6, T7 *...)
=========     by thread (64,0,0) in block (0,71,0)
=========     Address 0x7f5a24c00000 is out of bounds
=========     and is 1 bytes after the nearest allocation at 0x7f5a24a00000 of size 2,097,152 bytes
=========     Saved host backtrace up to driver entry point at kernel launch time
=========     Host Frame: [0x3050c2]
=========                in /lib/x86_64-linux-gnu/libcuda.so.1
=========     Host Frame: [0x8fea4b]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/../../../../libcusparse.so.11
=========     Host Frame: [0x95b5e8]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/../../../../libcusparse.so.11
=========     Host Frame: [0x824b1]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/../../../../libcusparse.so.11
=========     Host Frame: [0x4b9d79]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/../../../../libcusparse.so.11
=========     Host Frame: [0x4f836d]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/../../../../libcusparse.so.11
=========     Host Frame:cusparseSpMM [0xff1c9]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/../../../../libcusparse.so.11
=========     Host Frame:void at::native::sparse::cuda::(anonymous namespace)::_csrmm2<double>(char, char, long, long, long, long, double*, double*, int*, int*, double*, long, double*, double*, long, cudaDataType_t) [clone .constprop.0] [0x2e8c222]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:void at::native::sparse::cuda::csrmm2<double>(char, char, long, long, long, long, double, double*, int*, int*, double*, long, double, double*, long) [0x2e8e2f5]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:at::native::s_addmm_out_csr_sparse_dense_cuda_worker(long, long, long, long, at::Tensor const&, c10::Scalar const&, at::Tensor const&, c10::Scalar const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&)::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const [0x2e7e920]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:at::native::s_addmm_out_csr_sparse_dense_cuda_worker(long, long, long, long, at::Tensor const&, c10::Scalar const&, at::Tensor const&, c10::Scalar const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&) [0x2e818e4]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:at::native::s_addmm_out_sparse_dense_cuda_worker(long, long, long, long, at::Tensor&, c10::Scalar const&, at::Tensor const&, c10::Scalar const&, at::Tensor&, at::Tensor&, at::Tensor const&) [0x29639c1]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:at::native::s_addmm_out_sparse_dense_cuda(at::Tensor&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&) [0x2964086]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:at::native::s_addmm_sparse_dense_cuda(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&) [0x2964880]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:at::native::addmm_sparse_dense_cuda(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&) [0x2964df7]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:at::(anonymous namespace)::(anonymous namespace)::wrapper_SparseCUDA__addmm(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&) [0x2d8833d]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_SparseCUDA__addmm>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&> >, at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&) [0x2d883cd]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:at::_ops::addmm::call(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&) [0x21195a1]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:at::native::_sparse_addmm(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&) [0x1daa50a]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd___sparse_addmm>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&> >, at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&) [0x28b86dd]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:at::_ops::_sparse_addmm::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&) [0x20af9b6]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:torch::autograd::VariableType::(anonymous namespace)::_sparse_addmm(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&) [0x3c0a623]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&), &torch::autograd::VariableType::(anonymous namespace)::_sparse_addmm>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&> >, at::Tensor (c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&) [0x3c0afe3]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:at::_ops::_sparse_addmm::call(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&) [0x2119231]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:at::native::_sparse_mm(at::Tensor const&, at::Tensor const&) [0x1daf233]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd___sparse_mm>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&> >, at::Tensor (at::Tensor const&, at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) [0x2a83fb0]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:at::_ops::_sparse_mm::call(at::Tensor const&, at::Tensor const&) [0x23e3741]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:torch::autograd::THPVariable__sparse_mm(_object*, _object*, _object*) [0x66183d]
=========                in /home/srinath/miniconda3/envs/gpm/lib/python3.10/site-packages/torch/lib/libtorch_python.so
=========     Host Frame:/usr/local/src/conda/python-3.10.13/Objects/methodobject.c:554:cfunction_call [0xfc697]
=========                in /home/srinath/miniconda3/envs/gpm/bin/python
=========     Host Frame:/usr/local/src/conda/python-3.10.13/Objects/call.c:216:_PyObject_MakeTpCall [0xf614b]
=========                in /home/srinath/miniconda3/envs/gpm/bin/python
=========     Host Frame:/usr/local/src/conda/python-3.10.13/Python/ceval.c:4181:_PyEval_EvalFrameDefault [0xf2376]
=========                in /home/srinath/miniconda3/envs/gpm/bin/python
=========     Host Frame:/usr/local/src/conda/python-3.10.13/Python/ceval.c:5074:_PyEval_Vector [0x191d92]
=========                in /home/srinath/miniconda3/envs/gpm/bin/python
=========     Host Frame:/usr/local/src/conda/python-3.10.13/Python/ceval.c:1135:PyEval_EvalCode [0x191cd7]
=========                in /home/srinath/miniconda3/envs/gpm/bin/python
=========     Host Frame:/usr/local/src/conda/python-3.10.13/Python/pythonrun.c:1292:run_eval_code_obj [0x1c2967]
=========                in /home/srinath/miniconda3/envs/gpm/bin/python
=========     Host Frame:/usr/local/src/conda/python-3.10.13/Python/pythonrun.c:1313:run_mod [0x1bdad0]
=========                in /home/srinath/miniconda3/envs/gpm/bin/python
=========     Host Frame:/usr/local/src/conda/python-3.10.13/Python/pythonrun.c:1208:pyrun_file.cold [0x5956b]
=========                in /home/srinath/miniconda3/envs/gpm/bin/python
=========     Host Frame:/usr/local/src/conda/python-3.10.13/Python/pythonrun.c:456:_PyRun_SimpleFileObject [0x1b805f]
=========                in /home/srinath/miniconda3/envs/gpm/bin/python
=========     Host Frame:/usr/local/src/conda/python-3.10.13/Python/pythonrun.c:90:_PyRun_AnyFileObject [0x1b7dc3]
=========                in /home/srinath/miniconda3/envs/gpm/bin/python
=========     Host Frame:/usr/local/src/conda/python-3.10.13/Modules/main.c:670:Py_RunMain [0x1b4b7d]
=========                in /home/srinath/miniconda3/envs/gpm/bin/python
=========     Host Frame:/usr/local/src/conda/python-3.10.13/Modules/main.c:1091:Py_BytesMain [0x184e49]
=========                in /home/srinath/miniconda3/envs/gpm/bin/python
=========     Host Frame:../sysdeps/nptl/libc_start_call_main.h:74:__libc_start_call_main [0x23a90]
=========                in /lib/x86_64-linux-gnu/libc.so.6
=========     Host Frame:../csu/libc-start.c:347:__libc_start_main [0x23b49]
=========                in /lib/x86_64-linux-gnu/libc.so.6
=========     Host Frame: [0x184cfe]
=========                in /home/srinath/miniconda3/envs/gpm/bin/python
...
Invalid __global__ read of size 8 bytes
=========     at 0x1470 in void cusparse::load_balancing_kernel<(unsigned int)256, (unsigned int)1, (unsigned long)0, int, int, cusparse::CsrMMOpAlg1<cusparse::CsrMMPolicyAlg1<int, double, double, double>, (bool)0, (bool)0, (bool)1, double, int>, int, double, double, double>(const T5 *, T4, T5, T5, int, const T4 *, T6, T7 *...)
=========     by thread (64,0,0) in block (0,71,0)
=========     Address 0x7f5a24c00000 is out of bounds

Thank you for the logs! Could you update to the latest nightly build shipping with CUDA 12.1 and check if the error is still reproducible?

I just tried the latest nightly (PyTorch 2.2.0.dev20230921, build py3.10_cuda12.1_cudnn8.9.2_0). The error is still reproducible with CUDA 12.1.

(test_env) srinath@vapnik:~/test$ conda list | grep 'pytorch'
libjpeg-turbo             2.0.0                h9bf148f_0    pytorch-nightly
pytorch                   2.2.0.dev20230921 py3.10_cuda12.1_cudnn8.9.2_0    pytorch-nightly
pytorch-cuda              12.1                 ha16c6d3_5    pytorch-nightly
pytorch-mutex             1.0                        cuda    pytorch-nightly
torchaudio                2.2.0.dev20230921     py310_cu121    pytorch-nightly
torchtriton               2.1.0+6e4932cda8           py310    pytorch-nightly
torchvision               0.17.0.dev20230921     py310_cu121    pytorch-nightly
(test_env) srinath@vapnik:~/test$ python test_sparse_mm.py 
 i :0 

0 (1, 1) torch.Size([1, 1202])
 i :0 

1000 (1815, 1) torch.Size([1, 1694])
1205 (956, 1) torch.Size([1, 1612])
Traceback (most recent call last):
  File "/home/srinath/test/test_sparse_mm.py", line 27, in <module>
    mat = torch.sparse.mm(sparse, dense)
RuntimeError: CUDA error: an illegal memory access was encountered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Thanks for the quick test and the code snippet!
I was also able to reproduce the issue and have forwarded it to our cuSPARSE team to take a look at it.

Thank you for forwarding this issue to the cuSPARSE team! Looking forward to hearing from you soon.

For anyone encountering this problem, a temporary workaround is to use an older PyTorch build with CUDA 11.3 or CUDA 11.6 (verified with PyTorch 1.11 - py3.8_cuda11.3_cudnn8.2.0_0 and PyTorch 1.12.1 - py3.10_cuda11.6_cudnn8.3.2_0). I suggest downgrading to PyTorch 1.12.1 with CUDA 11.6, since its GPU memory usage for sparse-dense matrix operations is lower than with CUDA 11.3. If you do not want to downgrade PyTorch, you can instead use the following sparse-dense multiplication function from the pytorch-sparse library (see the sketch below):

torch_sparse.spmm
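
A minimal sketch of that replacement, assuming the spmm(index, value, m, n, dense) signature documented by pytorch-sparse (installed via pip install torch-sparse); the shapes below mirror the single-nonzero setup from the repro script:

import torch
import torch_sparse  # from the pytorch-sparse package

device = torch.device(0 if torch.cuda.is_available() else "cpu")

# same single-nonzero sparse matrix as in the repro script above
sparse_size = (956, 1)                                       # (m, n)
index = torch.tensor([[0], [0]], device=device)              # (2, nnz) COO indices
value = torch.ones(1, dtype=torch.float64, device=device)    # (nnz,) values
dense = torch.rand((1, 1612), dtype=torch.float64, device=device)  # (n, k)

# spmm(index, value, m, n, dense) returns the (m, k) dense product
out = torch_sparse.spmm(index, value, sparse_size[0], sparse_size[1], dense)
print(out.shape)  # torch.Size([956, 1612])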