Memory leaks due to Python exception handling

I have been receiving a strange CUDA error that I couldn’t place without a lot of searching:

Traceback (most recent call last):
  File "/panfs/roc/groups/13/suo-yang/dikem003/DimensionReductionNLE/auto_ode/AETrainingConditionNum.py", line 241, in <module>
    batch_loss = reconstruction_loss + stiffness_loss*EPOCH_SCALER
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

The line reported in the stack trace is pretty benign, but after searching and finding this forum post I learned that Python exceptions can cause issues with memory deallocation in CUDA.

In my objective function I have to invert a matrix, and since I sample the data stochastically from each batch I can't be sure whether that matrix is invertible, so I wrap the operation in a try…except block in case it fails. It looks like this:

for batch_num, start_samp in zip(range(num_batches), start_points):
    try:
        start_idx = start_samp
        end_idx = start_samp + self.sample_size
        # slice offset matrices from dynamics
        Y1 = torch.transpose(latent_batches[batch_num, start_idx:end_idx - 1, :], 0, 1)
        Y2 = torch.transpose(latent_batches[batch_num, start_idx + 1:end_idx, :], 0, 1)
        # find linearized Jacobian ODE dynamics
        D = (Y2 @ torch.transpose(Y1, 0, 1)) @ torch.linalg.inv(Y1 @ torch.transpose(Y1, 0, 1))

        # find condition number of estimated dynamics matrix
        cond = torch.linalg.cond(D, p=2)
        # add calculated condition number to accumulator
        condition_number_penalty = torch.add(cond, condition_number_penalty)
    except RuntimeError as err:
        print('singmatrix')

Am I right in thinking this error is a potential memory deallocation problem within CUDA from this exception? If so, is there something I can do directly (keeping the exception handling that I have) that can fix the issue?

One other potential cause might be the case where every training iteration hits the exception, so that I end up passing just a torch.tensor([0]) to my loss function. That never causes an issue on CPU, but could that behavior cause a problem on the GPU due to some CUDA optimization of the computation graph?

One thing I was considering is making a copy of the matrix I want to invert, detached from the computation graph, and running a check on it to make sure the matrix is actually invertible before calling the inverse. Before I spend too much time on potential solutions, though, I would like to make sure I have diagnosed the problem correctly.
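Roughly what I have in mind is a hypothetical helper like this (just a sketch with made-up names; I haven't tested whether a rank check on a detached copy is a reliable guard):

import torch

def try_estimate_dynamics(Y1, Y2):
    """Return the estimated dynamics matrix D, or None if Y1 @ Y1.T looks singular.

    Y1 and Y2 are the offset latent slices from the snippet above; the rank
    check runs on a detached copy so it never touches the computation graph.
    """
    gram = Y1 @ torch.transpose(Y1, 0, 1)
    with torch.no_grad():
        full_rank = torch.linalg.matrix_rank(gram.detach()) == gram.shape[0]
    if not full_rank:
        return None
    return (Y2 @ torch.transpose(Y1, 0, 1)) @ torch.linalg.inv(gram)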

Edit: I forgot to mention the reasons I think this exception catching is the cause of the problem:

  1. typically a job fails right AFTER a few of the print statements from the except block are shown
  2. other custom loss functions I use, which are treated identically but without the try…except and the matrix inversion, work fine

No, I don’t think the Python exception handling could cause an illegal memory access; in the worst case it would keep the tensor alive without releasing it to the cache.
To further isolate the issue, you could rerun your script via CUDA_LAUNCH_BLOCKING=1 as suggested in the error message.

Hi @ptrblck, I was actually able to solve the issue after a couple of days of debugging. When I ran the script with CUDA_LAUNCH_BLOCKING=1 it just returned an identical traceback and error, minus the tip telling me to run with the flag, so I had to change things until it worked.

The issue ended up being that, even though I created the accumulated loss as a 0D tensor, I was unaware that torch.linalg.cond sometimes returns a 0D tensor and sometimes a 1D tensor with one element, depending on which method you select (which you can see in the examples section of the documentation). When I accumulated the loss over each sample with torch.add, this produced unexpected behavior: I had switched the method of calculating the condition number around the same time I added the exception handling, so I ended up calling backward on a tensor of size [1] for my accumulated loss.

A couple of things still seem strange to me: 1) why would I get such an ambiguous error message for the simple mistake of calling backward on a tensor of the wrong dimension? and 2) torch.linalg.cond returns tensors of different dimensions for different methods without much mention of this in the documentation.

The latter is what burned me, I think. All I had to do was call torch.squeeze on the output of torch.linalg.cond before accumulating the loss, and now it works fine.
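For reference, this is roughly the shape behaviour that tripped me up on the version I'm running (1.9); the fix is just the extra squeeze before accumulating:

import torch

A = torch.randn(20, 20)

print(torch.linalg.cond(A, p='fro').shape)  # torch.Size([])  -- a 0D scalar
print(torch.linalg.cond(A, p=2).shape)      # torch.Size([1]) on 1.9 -- a 1D tensor

# squeezing before accumulating keeps the running penalty 0D either way
penalty = torch.zeros(())
penalty = torch.add(torch.squeeze(torch.linalg.cond(A, p=2)), penalty)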

  1. You shouldn’t get an illegal memory access, as assert statements should guard against it properly. Which PyTorch version are you using, and could you post a minimal code snippet to reproduce the memory violation in case you are using the latest release?

  2. If you are using the latest release and are facing the shape issue, could you create an issue/feature request on GitHub so that we could check if changing the returned shape would be feasible?

cudatoolkit               10.2.89
torch                     1.9.0

These are my CUDA and torch versions. I am certain PyTorch is current since I upgraded within the last 48 hours; I am less sure about CUDA, but I believe it is as well.

  1. On the first point, you may have been right. After doing what I thought was the fix, I still ran into this error:
Traceback (most recent call last):
  File "/panfs/roc/groups/13/suo-yang/dikem003/DimensionReductionNLE/auto_ode/AETrainingConditionNum.py", line 238, in <module>
    stiffness_loss = AE_stiffness_loss(latent_predictions) * EPOCH_SCALER
RuntimeError: CUDA error: an illegal memory access was encountered

After it failed again, I added some debugging statements in my main training loop to see the sizes of the tensors returned by both loss functions (while changing nothing else):

# print shapes
print('\nRL Size',reconstruction_loss.size())
print('\nSL Size',stiffness_loss.size())

Now the error suddenly does not appear, as if referencing the sizes of both loss tensors before calling backward has eliminated the issue altogether. I even checked my git diff to confirm I changed nothing else; I only added print statements. At the time of writing I have made it to ~100 epochs, whereas it was failing consistently at around 5-10 before.

I can try to develop a minimal reproducible program with this loss function, since I know my one example isn’t worth much; I will need to figure out how to trigger the non-invertibility exception without my specialized set of data. Right now I feel like I’m chasing Schrödinger’s bug, which is driving me nuts. I don’t see any reason why checking the loss tensor sizes in my training loop should affect anything.
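The simplest trigger I can think of is an exactly singular matrix, something along these lines (a sketch, not my real data):

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# a matrix with a zeroed-out row is exactly singular, so torch.linalg.inv
# should raise the same RuntimeError that my rank-deficient Gram matrices do
bad = torch.randn(20, 20, device=device)
bad[0] = 0.0

try:
    torch.linalg.inv(bad)
except RuntimeError as err:
    print('singmatrix:', err)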

  1. The shape issue for torch.linalg.cond is reflected in the documentation, and I can put together a small program to demonstrate it and raise a GitHub issue, no problem there.

Could you try to create a CUDA coredump using the code that is failing quickly, via:

export CUDA_ENABLE_COREDUMP_ON_EXCEPTION=1
python script.py args

Once the IMA is triggered, a file should have been created in the working directory, which you could use to get a valid stacktrace.
Running with cuda-gdb might also work, but since you are seeing different behavior when printing stuff, it could point towards a potential race condition, so I would try to avoid slowing down the code.

Hi @ptrblck, I have generated the coredump, but I’ll be honest: I’ve never used the CUDA debugger before, so I’m not sure what to do next. I don’t expect you to walk me through it, but if you can point me towards some useful documentation that would allow me to get a stacktrace, I would appreciate it. I do somewhat worry that I might not be able to make sense of a valid stacktrace even if I obtain one, since my experience is limited to PyTorch rather than programming in CUDA directly.

If you think there is a potential race condition, is there some place in my program where I could try manually clearing the GPU cache that might eliminate the issue? Or do you think it actually is some issue with the Python exception handling causing this? I feel myself quickly getting out of my depth trying to fix it.

Either way, thanks a ton for your help so far

Great news that you were able to create the coredump.
Could you forward me the stacktrace via:

# launch cuda-gdb
cuda-gdb

# inside cuda-gdb
target cudacore file_name_of_coredump
bt

I don’t expect you to fix these issues, but it would be great if you could provide the corresponding code and, if possible, the coredump itself (it might be huge, so you might need to use e.g. Google Drive).
Inside the coredump you would be able to see the $pc as well as the last instructions, but to make sense of these you would need to be somewhat familiar with the GPU machine instructions, so let’s focus on the backtrace first.

Since I have not used the debugger before, I wasn’t able to invoke cuda-gdb from the command line, and I’m unsure whether I need to download the debugger separately in order to do so. Does cuda-gdb come automatically installed with the cudatoolkit version I mentioned previously (which I believe I installed with either conda or pip)? If so, how do I run it from the command line? If not, I would need to figure out a way to get it separately through my account at our supercomputing institute.

Not sure if I mentioned that I am on a remote machine via SSH and running these jobs with Slurm in the V100 GPU cluster at my university.

I have no problem sharing the code and coredump with you, but since this is in relation to unpublished work I would probably need to send the file links to you over the messaging on this forum or somewhere else rather than post the links or files openly in this thread.

Edit: I checked the available modules at our SC institute: we have CUDA versions 6.5-11.2, cuDNN 3.0-8.2, and CUDA SDK 6.5-11.2. I can load any of those if that will get the debugger up and running.

No, the cudatoolkit installed via conda/pip does not include the debugger; cuda-gdb comes with a locally installed CUDA toolkit.
As a quick way to share the backtrace, you could start a docker container with an installed CUDA toolkit, e.g. nvidia/cuda:11.4.0-devel-ubuntu20.04, mount the folder where the coredump is located, and execute cuda-gdb inside it:

docker run -it --gpus=all  -v /folder_with_coredump:/debug nvidia/cuda:11.4.0-devel-ubuntu20.04

The coredump will then be located in /debug inside the container.

No, you shouldn’t share unreleased code with anyone, so let’s try to get the stacktrace first.
One interesting point would be to know whether you are using any custom CUDA extensions or plain PyTorch code.

Hi @ptrblck,

I will work on getting the stacktrace. It would be much easier for me to load a preexisting CUDA installation through my SC institute than to use a Docker container, since I don’t have much experience with Docker.

As for your second question, I am not using any custom CUDA extensions; it is solely vanilla PyTorch code.

Edit: Here is the backtrace –

(cuda-gdb) target cudacore core_1627076158_cn2102_2891958.nvcudmp
Opening GPU coredump: core_1627076158_cn2102_2891958.nvcudmp
[New Thread 2892012]

CUDA Exception: Warp Illegal Address
The exception was triggered at PC 0x561e719946f0
[Current focus set to CUDA kernel 0, grid 1263657, block (0,0,0), thread (0,0,0), device 0, sm 0, warp 3, lane 0]
#0 0x0000561e71994720 in void laswp_kernel2<double, false>(int, double*, unsigned long, int, int, int const*, int)<<<(1,1,1),(4,64,1)>>> ()

(cuda-gdb) bt
#0 0x0000561e71994720 in void laswp_kernel2<double, false>(int, double*, unsigned long, int, int, int const*, int)<<<(1,1,1),(4,64,1)>>> ()

I think it spit out the backtrace automatically, but I ran the bt command once more just to be sure. This was generated with a locally loaded cuda-gdb installation that came with CUDA 11.2.

I probably should have mentioned this upfront, but all my training was conducted on a single V100 GPU.

@ptrblck I'm unsure if this clarifies anything, but after my loss accumulation loop (the one with the try…except statement I showed above) I added a call to torch.cuda.synchronize and received an identical error, but with a stacktrace that now points to the synchronize call:

Traceback (most recent call last):
  File "/panfs/roc/groups/13/suo-yang/dikem003/DimensionReductionNLE/auto_ode/AETrainingConditionNum.py", line 238, in <module>
    stiffness_loss = AE_stiffness_loss(latent_predictions) * EPOCH_SCALER
  File "/home/suo-yang/dikem003/.conda/envs/torchcombust/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/panfs/roc/groups/13/suo-yang/dikem003/DimensionReductionNLE/auto_ode/CustomLossFunctions.py", line 226, in forward
    torch.cuda.synchronize(self.DEVICE)
  File "/home/suo-yang/dikem003/.conda/envs/torchcombust/lib/python3.9/site-packages/torch/cuda/__init__.py", line 446, in synchronize
    return torch._C._cuda_synchronize()
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
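For context, the synchronize call sits at the very end of the loss function's forward, roughly like this (paraphrased from memory; the class name here is made up):

import torch

class StiffnessLoss(torch.nn.Module):  # stand-in name for my custom loss
    def __init__(self, device):
        super().__init__()
        self.DEVICE = device

    def forward(self, latent_batches):
        penalty = torch.zeros((), device=self.DEVICE)
        # ... per-batch condition-number accumulation with the try/except shown above ...
        # flush all queued kernels so any asynchronous CUDA error is raised
        # here, inside the loss function, instead of at a later unrelated call
        torch.cuda.synchronize(self.DEVICE)
        return penalty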

Also, after exporting the flag that disables memory caching along with the debugging flag
export PYTORCH_NO_CUDA_MEMORY_CACHING=1
export CUDA_LAUNCH_BLOCKING=1

I got a more helpful stacktrace:

terminate called after throwing an instance of 'c10::CUDAError'
  what():  CUDA error: an illegal memory access was encountered
Exception raised from uncached_delete at /pytorch/c10/cuda/CUDACachingAllocator.cpp:1307 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7f09bf565a22 in /home/suo-yang/dikem003/.conda/envs/torchcombust/lib/python3.9/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xfa5e (0x7f09bf7c5a5e in /home/suo-yang/dikem003/.conda/envs/torchcombust/lib/python3.9/site-packages/torch/lib/libc10_cuda.so)
frame #2: c10::TensorImpl::release_resources() + 0x54 (0x7f09bf54f5a4 in /home/suo-yang/dikem003/.conda/envs/torchcombust/lib/python3.9/site-packages/torch/lib/libc10.so)
frame #3: <unknown function> + 0xe5e3e9 (0x7f09c08413e9 in /home/suo-yang/dikem003/.conda/envs/torchcombust/lib/python3.9/site-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0x16b4842 (0x7f09c1097842 in /home/suo-yang/dikem003/.conda/envs/torchcombust/lib/python3.9/site-packages/torch/lib/libtorch_cuda.so)
frame #5: at::native::_linalg_inv_out_helper_cuda_lib(at::Tensor&, at::Tensor&, at::Tensor&) + 0x374 (0x7f09c109e244 in /home/suo-yang/dikem003/.conda/envs/torchcombust/lib/python3.9/site-packages/torch/lib/libtorch_cuda.so)
frame #6: <unknown function> + 0xf22ffa (0x7f09c0905ffa in /home/suo-yang/dikem003/.conda/envs/torchcombust/lib/python3.9/site-packages/torch/lib/libtorch_cuda.so)
frame #7: at::_linalg_inv_out_helper_(at::Tensor&, at::Tensor&, at::Tensor&) + 0x133 (0x7f0a03625093 in /home/suo-yang/dikem003/.conda/envs/torchcombust/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #8: <unknown function> + 0x1001850 (0x7f0a02ee8850 in /home/suo-yang/dikem003/.conda/envs/torchcombust/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #9: at::native::linalg_inv_ex_out(at::Tensor const&, bool, at::Tensor&, at::Tensor&) + 0xc3 (0x7f0a02ee9103 in /home/suo-yang/dikem003/.conda/envs/torchcombust/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #10: at::native::linalg_inv_ex(at::Tensor const&, bool) + 0xe6 (0x7f0a02ee9366 in /home/suo-yang/dikem003/.conda/envs/torchcombust/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #11: <unknown function> + 0x1af7e40 (0x7f0a039dee40 in /home/suo-yang/dikem003/.conda/envs/torchcombust/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #12: at::redispatch::linalg_inv_ex(c10::DispatchKeySet, at::Tensor const&, bool) + 0xb1 (0x7f0a036fefc1 in /home/suo-yang/dikem003/.conda/envs/torchcombust/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #13: <unknown function> + 0x2f90d8b (0x7f0a04e77d8b in /home/suo-yang/dikem003/.conda/envs/torchcombust/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #14: <unknown function> + 0x2f91223 (0x7f0a04e78223 in /home/suo-yang/dikem003/.conda/envs/torchcombust/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #15: at::linalg_inv_ex(at::Tensor const&, bool) + 0x11b (0x7f0a03546dbb in /home/suo-yang/dikem003/.conda/envs/torchcombust/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #16: at::native::linalg_inv(at::Tensor const&) + 0x30 (0x7f0a02edfbf0 in /home/suo-yang/dikem003/.conda/envs/torchcombust/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #17: <unknown function> + 0x1b71fcc (0x7f0a03a58fcc in /home/suo-yang/dikem003/.conda/envs/torchcombust/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #18: at::linalg_inv(at::Tensor const&) + 0x111 (0x7f0a03409271 in /home/suo-yang/dikem003/.conda/envs/torchcombust/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #19: <unknown function> + 0x8dd4a4 (0x7f0a15f734a4 in /home/suo-yang/dikem003/.conda/envs/torchcombust/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #41: __libc_start_main + 0xf5 (0x7f0a18b57555 in /lib64/libc.so.6)

/var/spool/slurmd/job5169384/slurm_script: line 27: 2939038 Aborted                 python3 -u AETrainingConditionNum.py -an=softplus_ratiolowersing_high -aa=softplus -ad=2 -rn=null -cs=500 -cw=0.1 --condlosstype=ratiolowersingular > lowersing2.out

I see mention of the linalg.inv function, which is the one I am catching the exception from (when a given matrix is not invertible). So it seems to be some issue with trying to delete a cached allocation, tied to the exception, that may not exist?

Would the implication be that there is currently no way to catch an exception from an inversion operation while using CUDA?

Thanks for the updates!

The coredump points to the laswp MAGMA function, which seems to trigger the illegal memory access, and is called by linalg_inv.
@xwang233 has been working with others on using cuSOLVER instead of MAGMA (as we’ve seen such memory violations in MAGMA before), but based on this comment batched inputs seem to still call into MAGMA for performance reasons.

Are you seeing the memory violation only when the exception has been caught?

Yes, I only see this illegal memory access raised once the singular-matrix exception is caught. When I happen to weight this custom loss function lower relative to my reconstruction loss (this program is for a regularized autoencoder), my matrices never become singular and training proceeds as normal. I can tell because I print a small warning to the console when the exception is triggered.

However, I have noticed that typically several exceptions are caught before the memory access error is raised. It is also somewhat strange that the two debug print statements mentioned earlier mitigated the issue (I even saw several warnings for the caught exception in the log file, yet training continued); I would have thought that once this exception is caught, the rest is deterministic and the CUDA error would definitely be raised. Is there potentially some delay in communication between Python try/except statements and the GPU? I’m not sure how far in advance CUDA tends to queue up operations.

It’s also interesting that the function causing the issue is typically called only for batched inputs, since the for loop in the original code snippet loops through each batch and calls torch.linalg.inv separately. In fact, each call to the inverse that raises the RuntimeError I have to catch operates on a single matrix of size [20, 20]. Is the cuSOLVER change a future switch, or should it already be in effect for PyTorch 1.9?

Just want to take another opportunity to say thank you for your help and for your patience in dealing with a primarily plain-PyTorch user. Very much appreciated!

The “strangeness” of the error reporting is most likely due to the asynchronous execution of CUDA kernels, and since e.g. print statements might be synchronizing, you might indeed either mask or avoid the issue. Especially if you are dealing with a race condition, changing the execution order or speed could “hide” the memory violation, and the illegal memory reads or writes might not be hit.
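As a toy illustration of the asynchronous reporting (an artificial example, unrelated to your code):

import torch

x = torch.randn(10, device='cuda')
bad_idx = torch.tensor([100], device='cuda')  # deliberately out of range

y = x[bad_idx]   # the gather kernel is only queued here; Python keeps going
z = y * 2.0      # more work is queued on the same stream

# the device-side failure from the first kernel often only surfaces at a later
# synchronizing call, so the reported traceback points here instead;
# with CUDA_LAUNCH_BLOCKING=1 it would be reported at the faulty line itself
print(z.cpu())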
cuSOLVER has already been activated for a lot of methods, as described in this blog post, but due to some performance issues we are still seeing compared to MAGMA (we are working with cuSOLVER to accelerate these methods), MAGMA is still used for some workloads.
However, since you are apparently hitting the IMA in MAGMA, we might consider accepting a slower execution in cuSOLVER to avoid this functional issue in MAGMA.

EDIT: just as one example of debugging a race: some time ago I was hunting down an invalid output of another method and was able to reproduce it only when I pushed the GPU utilization to the max in another process, which slowed down the faulty code and allowed me to narrow it down to a missing CUDA stream synchronization. :stuck_out_tongue:

Understood, I can see now how the asynchronicity makes it considerably more difficult to track down these sorts of errors. I was also unaware that torch.linalg was so new; it makes sense that issues like this would come up! I will definitely keep track of PyTorch changes on the MAGMA vs. cuSOLVER front.

In my case, do you have any advice for avoiding this particular error? Alternatives to torch.linalg.inv, ways of manipulating the GPU cache that might help, delaying or synchronizing in certain parts of the program? I know you might not have an exact answer, but any potential avenue for getting around this illegal memory access would be a great start. It’d be great to get some version of this loss function working, even if these issues can’t be resolved cleanly until later changes to torch.linalg.
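To make the question concrete, the kind of alternative I have in mind would be swapping the explicit inverse for a pseudo-inverse, something like this (an untested sketch; I have no idea whether it sidesteps the MAGMA kernel from the coredump):

import torch

def estimate_dynamics(Y1, Y2):
    """Estimate D = (Y2 Y1^T)(Y1 Y1^T)^+ without an explicit inverse.

    torch.linalg.pinv is SVD-based, so it does not raise on singular Gram
    matrices, which would also remove the need for the try/except entirely.
    """
    gram = Y1 @ torch.transpose(Y1, 0, 1)
    return (Y2 @ torch.transpose(Y1, 0, 1)) @ torch.linalg.pinv(gram)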