Training hangs when using `triangular_solve`, unsure how to debug

I’m experimenting with a model whose forward pass involves multiplications and back-solves on large, sparse tensors. To get right to the point: training will occasionally just stall (progress stops, GPU utilization sits at 100% according to nvidia-smi) some number of steps in, and I’m having trouble even figuring out how to debug it. Below I describe what I’m doing, the problem I’m encountering, and what I’ve tried so far. Unfortunately, I haven’t been able to reproduce this exact hang outside of my own model and dataset, but I believe I have found race conditions when running under NVIDIA’s compute-sanitizer, which I explain below.

My torch version is '1.12.0+cu116' and I’m running on an NVIDIA A100. This is what my nvidia-smi shows:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.141.03   Driver Version: 470.141.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+

Forward pass description
As mentioned above, the forward pass multiplies by large, sparse matrices. There is also a pair of large, sparse triangular tensors (an LU decomposition) that I want to back-solve against, to avoid instantiating a dense matrix just to do a multiplication. All of the sparse matrices are part of the dataset, so they are fixed. A toy version of my forward pass looks something like:

import torch
from torch import nn

# solve_triangular is the custom autograd Function defined below
class MyLayer(nn.Module):
    def __init__(self, channels):
        super(MyLayer, self).__init__()
        self.conv_1 = nn.Linear(channels, channels)
        self.conv_2 = nn.Linear(channels, channels)

    def forward(self, x, L, U, A, B):
        # x: [n, c] - no batch dimension
        # L, U: [n, n] - sparse triangular LU factors to back-solve against
        # A, B: [n, n] - sparse tensors to multiply by

        x = self.conv_1(x)

        # Cast to double for the sparse operations
        x = x.to(torch.double)

        x = A @ x

        # Back-solve with the lower factor, then the upper factor
        x = solve_triangular(L, x, False)
        x = solve_triangular(U, x, True)

        x = B @ x

        # Cast back to float
        x = x.to(torch.float)

        x = self.conv_2(x)

        return x

Regarding the use of solve_triangular, this is the code I’m using:

class SolveTriangular(torch.autograd.Function):
    # NOTE: This uses the argument order of torch.linalg.solve_triangular() even
    # though it calls torch.triangular_solve() for the implementation
    @staticmethod
    def forward(ctx, A, B, upper):
        ctx.save_for_backward(A)
        ctx.upper = upper
        # Solve A X = B with A triangular (and possibly sparse)
        X = torch.triangular_solve(B, A, upper=upper).solution
        return X

    @staticmethod
    def backward(ctx, grad_X):
        (A,) = ctx.saved_tensors
        upper = ctx.upper
        grad_A = grad_B = None

        assert not ctx.needs_input_grad[0], "Gradient of A not supported"
        if ctx.needs_input_grad[1]:
            # d(loss)/dB = A^{-T} d(loss)/dX, i.e. a transposed triangular solve
            grad_B = torch.triangular_solve(
                grad_X, A, upper=upper, transpose=True
            ).solution
        return grad_A, grad_B, None


solve_triangular = SolveTriangular.apply
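
As an aside, the custom backward can be sanity-checked on small dense, double-precision inputs with torch.autograd.gradcheck; a minimal sketch (the sizes here are arbitrary):

# Sanity-check sketch for the custom backward on small dense inputs (sizes are arbitrary)
A = torch.tril(torch.rand(5, 5, dtype=torch.double)) + 5 * torch.eye(5, dtype=torch.double)
B = torch.rand(5, 3, dtype=torch.double, requires_grad=True)

# Only the gradient w.r.t. B is implemented, so A is kept fixed here
torch.autograd.gradcheck(lambda b: SolveTriangular.apply(A, b, False), (B,))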

Although the docs say torch.triangular_solve is deprecated in favor of torch.linalg.solve_triangular, I’m using the former because it is the only one that works with sparse tensors. I also looked into torch.lu_solve, but as far as I can tell that function doesn’t accept sparse tensor arguments either, so it wouldn’t work in my case.
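
Just to illustrate the argument-order difference that the NOTE in the code refers to, here is a quick dense-only check (the matrices are made up):

# Dense-only illustration: the two APIs agree up to the swapped argument order (made-up matrices)
A = torch.tril(torch.rand(4, 4, dtype=torch.double)) + 4 * torch.eye(4, dtype=torch.double)
B = torch.rand(4, 2, dtype=torch.double)

X_old = torch.triangular_solve(B, A, upper=False).solution   # deprecated API: (B, A)
X_new = torch.linalg.solve_triangular(A, B, upper=False)     # new API: (A, B)
assert torch.allclose(X_old, X_new)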

Problem description
When I launch training, it seems to run fine, and the loss goes down as it takes steps. However, after a variable amount of time (sometimes within 3 epochs, sometimes 15), the tqdm progress bar just stops and stays frozen indefinitely, as far as I can tell. When I look at nvidia-smi, GPU utilization is pegged at 100%, whereas during normal training it fluctuates. The first time this happens, if I hit Ctrl+C, I get a stack trace that looks like some variation of this:

Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/process.py", line 317, in _bootstrap
    util._exit_function()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 317, in _bootstrap
    util._exit_function()
  File "/usr/lib/python3.10/multiprocessing/util.py", line 360, in _exit_function
    _run_finalizers()
  File "/usr/lib/python3.10/multiprocessing/util.py", line 360, in _exit_function
    _run_finalizers()
  File "/usr/lib/python3.10/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/usr/lib/python3.10/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/usr/lib/python3.10/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/usr/lib/python3.10/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/usr/lib/python3.10/multiprocessing/queues.py", line 199, in _finalize_join
    thread.join()
  File "/usr/lib/python3.10/multiprocessing/queues.py", line 199, in _finalize_join
    thread.join()
  File "/usr/lib/python3.10/threading.py", line 1096, in join
    self._wait_for_tstate_lock()
  File "/usr/lib/python3.10/threading.py", line 1096, in join
    self._wait_for_tstate_lock()
  File "/usr/lib/python3.10/threading.py", line 1116, in _wait_for_tstate_lock
    if lock.acquire(block, timeout):
  File "/usr/lib/python3.10/threading.py", line 1116, in _wait_for_tstate_lock
    if lock.acquire(block, timeout):

However, if I restart training, then the next time it hangs Ctrl+C does nothing, and I have to find the PID in nvidia-smi and run kill [PID] to stop it.

What I’ve tried

  1. I tried running the code under gdb and, when it hangs, breaking in to see where it is stuck. The hang is consistently on the torch.triangular_solve call, and it can happen in both the train loop and the eval loop.
  2. I saved the inputs to each call of torch.triangular_solve, and once I encountered a hang, loaded those inputs in a script and called torch.triangular_solve on them in a loop (roughly the sketch shown after this list). That did not reproduce the issue.
  3. This was a while ago, and I don’t fully understand what CUDA driver versions mean, but I followed some forum posts on NVIDIA’s website to install the latest driver and compute-sanitizer, and ran the model under it. The racecheck tool reported two race conditions in the code. I also ran compute-sanitizer on a small script that just calls torch.triangular_solve on random inputs, and racecheck reported two hazards there as well.
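
For reference, the replay from item 2 looks roughly like this (the file names and saved format are just placeholders for how I actually store the inputs):

# Rough sketch of the replay script from item 2 (file names and the saved format are placeholders)
import numpy as np
import torch

device = torch.device("cuda")

# The saved inputs of one torch.triangular_solve call (stored as dense numpy arrays here)
A_np = np.load("saved_A.npy")   # triangular matrix
B_np = np.load("saved_B.npy")   # right-hand side
upper = False

A = torch.from_numpy(A_np).to_sparse_csr().to(device)
B = torch.from_numpy(B_np).to(device)

# Re-run the exact call in a loop; this did not reproduce the hang
for _ in range(1000):
    X = torch.triangular_solve(B, A, upper=upper).solution
    torch.cuda.synchronize()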

Since installing compute-sanitizer I had to reset all of my installs/settings, so I’m now back on CUDA 11.4 according to nvidia-smi and no longer have compute-sanitizer, only cuda-memcheck, which behaves differently. When I run it in memcheck mode, I get this sort of error:

========= Program hit CUDA_ERROR_NOT_FOUND (error 500) due to "named symbol not found" on CUDA API call to cuGetProcAddress.
=========     Saved host backtrace up to driver entry point at error
=========     Host Frame:/usr/lib/x86_64-linux-gnu/libcuda.so.1 [0x212ae0]
=========     Host Frame:/home/loganspear/tlaloc/python/.venv/lib/python3.10/site-packages/torch/lib/libcudart-45da57e3.so.11.0 [0x27460]
=========     Host Frame:/home/loganspear/tlaloc/python/.venv/lib/python3.10/site-packages/torch/lib/libcudart-45da57e3.so.11.0 [0x2e8d7]
=========     Host Frame:/home/loganspear/tlaloc/python/.venv/lib/python3.10/site-packages/torch/lib/libcudart-45da57e3.so.11.0 [0x316c8]
=========     Host Frame:/lib/x86_64-linux-gnu/libpthread.so.0 [0xf907]
=========     Host Frame:/home/loganspear/tlaloc/python/.venv/lib/python3.10/site-packages/torch/lib/libcudart-45da57e3.so.11.0 [0x75ce9]
=========     Host Frame:/home/loganspear/tlaloc/python/.venv/lib/python3.10/site-packages/torch/lib/libcudart-45da57e3.so.11.0 [0x23737]
=========     Host Frame:/home/loganspear/tlaloc/python/.venv/lib/python3.10/site-packages/torch/lib/libcudart-45da57e3.so.11.0 (cudaGetDeviceProperties + 0x46) [0x47646]
=========     Host Frame:/home/loganspear/tlaloc/python/.venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda_cpp.so [0x133528]
=========     Host Frame:/lib/x86_64-linux-gnu/libpthread.so.0 [0xf907]
=========     Host Frame:/home/loganspear/tlaloc/python/.venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda_cpp.so (_ZN2at4cuda19getDevicePropertiesEl + 0xd7) [0x133727]
=========     Host Frame:/home/loganspear/tlaloc/python/.venv/lib/python3.10/site-packages/torch/lib/libtorch_python.so [0x9a9878]
=========     Host Frame:/home/loganspear/tlaloc/python/.venv/lib/python3.10/site-packages/torch/lib/libtorch_python.so [0x36bc55]
=========     Host Frame:python3 [0x10d454]
=========     Host Frame:python3 (_PyObject_MakeTpCall + 0x15e) [0x212a3e]
=========     Host Frame:python3 (_PyEval_EvalFrameDefault + 0x3aa) [0x18f12a]
=========     Host Frame:python3 [0x18e136]
=========     Host Frame:python3 [0xbbb9c]
=========     Host Frame:python3 [0x18e136]
=========     Host Frame:python3 [0xbbb9c]
=========     Host Frame:python3 [0x18e136]
=========     Host Frame:python3 [0xbbb9c]
=========     Host Frame:python3 [0x18e136]
=========     Host Frame:python3 [0xe379a]
=========     Host Frame:python3 [0x1ff215]
=========     Host Frame:python3 (PyObject_CallMethod + 0xce) [0x1ff6ee]
=========     Host Frame:/home/loganspear/tlaloc/python/.venv/lib/python3.10/site-packages/torch/lib/libtorch_python.so (_ZN5torch5utils14cuda_lazy_initEv + 0x4d) [0x95e4fd]
=========     Host Frame:/home/loganspear/tlaloc/python/.venv/lib/python3.10/site-packages/torch/lib/libtorch_python.so [0x415800]
=========     Host Frame:python3 [0x117229]
=========     Host Frame:python3 [0xbcb3c]
=========     Host Frame:python3 [0x18e136]
=========     Host Frame:python3 (PyEval_EvalCode + 0x117) [0x18e6a7]
=========     Host Frame:python3 [0x252d2c]
=========     Host Frame:python3 [0x10d7d2]
=========     Host Frame:python3 [0xbbb9c]
=========     Host Frame:python3 [0x18e136]
=========     Host Frame:python3 [0xbbb9c]
=========     Host Frame:python3 [0x18e136]
=========     Host Frame:python3 (_PyObject_Call + 0x1e6) [0x1ff8f6]
=========     Host Frame:python3 [0x2c29ef]
=========     Host Frame:python3 (Py_RunMain + 0x14e) [0x2c2bce]
=========     Host Frame:python3 (Py_BytesMain + 0x29) [0x2c3029]
=========     Host Frame:/lib/x86_64-linux-gnu/libc.so.6 (__libc_start_main + 0xe7) [0x21c87]
=========     Host Frame:python3 (_start + 0x2a) [0x21754a]

When I run it with --tool racecheck, the toy script I wrote (which just calls my version of solve_triangular on random triangular matrices; a rough sketch is included after the error output below) actually breaks, even though it runs fine outside of cuda-memcheck --tool racecheck. This is the error I get when running with CUDA_LAUNCH_BLOCKING=1:

========= CUDA-MEMCHECK
  L = torch.from_numpy(L).to_sparse_csr().to(device)
  0%|                                                                                                                                                               | 0/1000 [03:37<?, ?it/s]
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "...", line 104, in <module>
    loss.backward()
  File ".../lib/python3.10/site-packages/torch/_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File ".../lib/python3.10/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: CUDA error: an illegal memory access was encountered
========= RACECHECK SUMMARY: 0 hazards displayed (0 errors, 0 warnings)

Conclusion
I’d love to be able to perform back-solves on sparse, triangular tensors on the GPU, but I’m running into this problem and I’m just not sure how to go about debugging it, so I welcome any suggestions. I’m also trying to learn more about exactly what “CUDA drivers” are and what their version means, since I’m confused about why my CUDA version is 11.4 yet it still works with a PyTorch build for CUDA 11.6; if anyone has links to good resources on that, I’d appreciate it too. Most importantly, though, if anyone has suggestions on how I could further debug this, I would greatly appreciate it.

Your cuda-memcheck output isn’t detecting an error in PyTorch, but seems to fail directly while trying to execute the code:

========= Program hit CUDA_ERROR_NOT_FOUND (error 500) due to "named symbol not found" on CUDA API call to cuGetProcAddress.

which might point towards a setup issue.

Could you post a minimal, executable code snippet which would reproduce the illegal memory access, please?