Threaded Inference c10::CuDNNError

I’m running into issues with threaded inference (multiple threads sharing the same model on a single GPU) when using multiple model instances. I took a quick look through previous posts and didn’t find anything similar.

EDIT: I should add, I am using the following versions:

$ cat /usr/local/libtorch/build-version
2.0.1+cu118

$ g++ --version
g++ (Ubuntu 11.3.0-6ubuntu1) 11.3.0

$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0

$ cat /usr/local/cuda/include/cudnn_version.h | grep CUDNN_MAJOR -A 2
#define CUDNN_MAJOR 8
#define CUDNN_MINOR 9
#define CUDNN_PATCHLEVEL 1

Each individual model type is wrapped so that threads send queries to a thread-safe queue and receive the inference results back. On the model side, the model sits in its own thread, which pulls a query from the queue, runs it through the model, and sends the result back. This works fine when using multiple threads for inference (think of running a search algorithm where each thread has a separate problem and all of them query the same single GPU).
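
For reference, the wrapper has roughly this shape (a simplified sketch with made-up names, not the actual code):

#include <torch/torch.h>
#include <condition_variable>
#include <future>
#include <mutex>
#include <queue>
#include <thread>

// Worker threads push a request and block on a future; a single model thread
// owns the module and services requests one at a time, so access to any one
// model is already serialized.
struct InferenceServer {
  InferenceServer(torch::nn::AnyModule model, torch::Device device)
      : model_(std::move(model)), device_(device), worker_([this] { loop(); }) {}

  ~InferenceServer() {
    { std::lock_guard<std::mutex> lk(mu_); stop_ = true; }
    cv_.notify_all();
    worker_.join();
  }

  // Called from any number of search threads; blocks until the result is ready.
  torch::Tensor infer(torch::Tensor input) {
    std::promise<torch::Tensor> promise;
    auto result = promise.get_future();
    {
      std::lock_guard<std::mutex> lk(mu_);
      queue_.push({std::move(input), std::move(promise)});
    }
    cv_.notify_one();
    return result.get();
  }

 private:
  struct Request {
    torch::Tensor input;
    std::promise<torch::Tensor> result;
  };

  void loop() {
    torch::NoGradGuard no_grad;
    while (true) {
      Request req;
      {
        std::unique_lock<std::mutex> lk(mu_);
        cv_.wait(lk, [this] { return stop_ || !queue_.empty(); });
        if (stop_ && queue_.empty()) return;
        req = std::move(queue_.front());
        queue_.pop();
      }
      // The module is assumed to already live on device_.
      req.result.set_value(model_.forward(req.input.to(device_)));
    }
  }

  torch::nn::AnyModule model_;
  torch::Device device_;
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<Request> queue_;
  bool stop_ = false;
  std::thread worker_;
};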

The issue arises when I add a second model (for separate query types). Since it uses the same wrapper, no two threads ever hit the same model at the same time. However, thread 1 can run inference on model A at the same time that thread 2 runs inference on model B.

It’s a pretty big code base, so it’s hard to share the details, but to diagnose I have confirmed that: using only model A or only model B works; using both models but only a single inference thread works; and adding a global mutex lock (so thread 1 cannot query model A while thread 2 queries model B) also works.
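
The global lock workaround amounts to something like this (again only a sketch with hypothetical names), taken in every model’s worker thread right around the forward call:

#include <torch/torch.h>
#include <mutex>

// One mutex shared across all models: model A and model B can no longer be
// executing forward() on the GPU at the same time.
static std::mutex g_inference_mutex;

torch::Tensor locked_forward(torch::nn::AnyModule& model, const torch::Tensor& input) {
  std::lock_guard<std::mutex> lock(g_inference_mutex);
  return model.forward(input);
}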

Is this not supported? Can two threads run inference on separate models at the same time, or is there something in the CUDA/cuDNN backend that makes it problematic to call (potentially) the same kernel (for different models) from separate threads at the same time? Here is the runtime crash I get:

terminate called after throwing an instance of 'c10::CuDNNError'
  what():  cuDNN error: CUDNN_STATUS_EXECUTION_FAILED
Exception raised from run_conv_plan at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:224 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x6b (0x7fbf9e05a6bb in /usr/local/libtorch/lib/libc10.so)
frame #1: <unknown function> + 0x1217a0b (0x7fbf9f617a0b in /usr/local/libtorch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0x121f4b0 (0x7fbf9f61f4b0 in /usr/local/libtorch/lib/libtorch_cuda.so)
frame #3: <unknown function> + 0x1220795 (0x7fbf9f620795 in /usr/local/libtorch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0x12090ca (0x7fbf9f6090ca in /usr/local/libtorch/lib/libtorch_cuda.so)
frame #5: at::native::cudnn_convolution(at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool) + 0xa5 (0x7fbf9f6095c5 in /usr/local/libtorch/lib/libtorch_cuda.so)
frame #6: <unknown function> + 0x2dd1f71 (0x7fbfa11d1f71 in /usr/local/libtorch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0x2dd200f (0x7fbfa11d200f in /usr/local/libtorch/lib/libtorch_cuda.so)
frame #8: at::_ops::cudnn_convolution::call(at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool) + 0x21f (0x7fbff72503ff in /usr/local/libtorch/lib/libtorch_cpu.so)
frame #9: at::native::_convolution(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long, bool, bool, bool, bool) + 0x12dd (0x7fbff655b5ed in /usr/local/libtorch/lib/libtorch_cpu.so)
frame #10: <unknown function> + 0x27e16bc (0x7fbff75e16bc in /usr/local/libtorch/lib/libtorch_cpu.so)
frame #11: <unknown function> + 0x27e1784 (0x7fbff75e1784 in /usr/local/libtorch/lib/libtorch_cpu.so)
frame #12: at::_ops::_convolution::call(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<c10::SymInt>, c10::ArrayRef<long>, bool, c10::ArrayRef<c10::SymInt>, long, bool, bool, bool, bool) + 0x2b5 (0x7fbff6dee3f5 in /usr/local/libtorch/lib/libtorch_cpu.so)
frame #13: at::native::convolution(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long) + 0x15f (0x7fbff655215f in /usr/local/libtorch/lib/libtorch_cpu.so)
frame #14: <unknown function> + 0x27e12a2 (0x7fbff75e12a2 in /usr/local/libtorch/lib/libtorch_cpu.so)
frame #15: <unknown function> + 0x27e1322 (0x7fbff75e1322 in /usr/local/libtorch/lib/libtorch_cpu.so)
frame #16: at::_ops::convolution::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<c10::SymInt>, c10::ArrayRef<long>, bool, c10::ArrayRef<c10::SymInt>, long) + 0x23e (0x7fbff6dbc5fe in /usr/local/libtorch/lib/libtorch_cpu.so)
frame #17: <unknown function> + 0x3a802ed (0x7fbff88802ed in /usr/local/libtorch/lib/libtorch_cpu.so)
frame #18: <unknown function> + 0x3a80f66 (0x7fbff8880f66 in /usr/local/libtorch/lib/libtorch_cpu.so)
frame #19: at::_ops::convolution::call(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<c10::SymInt>, c10::ArrayRef<long>, bool, c10::ArrayRef<c10::SymInt>, long) + 0x23c (0x7fbff6ded83c in /usr/local/libtorch/lib/libtorch_cpu.so)
frame #20: at::native::conv2d(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long) + 0x224 (0x7fbff65552e4 in /usr/local/libtorch/lib/libtorch_cpu.so)
frame #21: <unknown function> + 0x29a9bd2 (0x7fbff77a9bd2 in /usr/local/libtorch/lib/libtorch_cpu.so)
frame #22: at::_ops::conv2d::call(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long) + 0x201 (0x7fbff736b991 in /usr/local/libtorch/lib/libtorch_cpu.so)
frame #23: <unknown function> + 0x51004e2 (0x7fbff9f004e2 in /usr/local/libtorch/lib/libtorch_cpu.so)
frame #24: torch::nn::Conv2dImpl::_conv_forward(at::Tensor const&, at::Tensor const&) + 0x44d (0x7fbff9ef961d in /usr/local/libtorch/lib/libtorch_cpu.so)
frame #25: torch::nn::Conv2dImpl::forward(at::Tensor const&) + 0x24 (0x7fbff9ef9864 in /usr/local/libtorch/lib/libtorch_cpu.so)
frame #26: <unknown function> + 0x4566f7 (0x563432e8e6f7 in ./src/main)
frame #27: <unknown function> + 0x3f8891 (0x563432e30891 in ./src/main)
frame #28: <unknown function> + 0x419449 (0x563432e51449 in ./src/main)
frame #29: <unknown function> + 0x41982d (0x563432e5182d in ./src/main)
frame #30: <unknown function> + 0x3f0490 (0x563432e28490 in ./src/main)
frame #31: <unknown function> + 0x3f3b77 (0x563432e2bb77 in ./src/main)
frame #32: <unknown function> + 0x3e0a0c (0x563432e18a0c in ./src/main)
frame #33: <unknown function> + 0x242a5b (0x563432c7aa5b in ./src/main)
frame #34: <unknown function> + 0xdc3a3 (0x7fbf972dc3a3 in /lib/x86_64-linux-gnu/libstdc++.so.6)
frame #35: <unknown function> + 0x90402 (0x7fbf96e90402 in /lib/x86_64-linux-gnu/libc.so.6)
frame #36: <unknown function> + 0x11f590 (0x7fbf96f1f590 in /lib/x86_64-linux-gnu/libc.so.6)

Aborted (core dumped)

I also can’t really reproduce this in a Debug build, as I imagine the slower code reduces the likelihood of two threads hitting model A and model B at the same time.

Interesting; cuDNN handles are ostensibly thread-safe, so I’m not sure what the bad interaction here is. Does setting the environment variable TORCH_CUDNN_V8_API_DISABLED=1 when running your workload produce a similar error?

I removed the previous global mutex lock workaround and ran again to verify that I still crash without it (same error as above). Then I set export TORCH_CUDNN_V8_API_DISABLED=1, did a clean recompile, and I have not crashed yet (after letting it run for much longer than would normally be needed to trigger a crash).

The only side effect is that the code is slower (my search benchmark goes from about 400 seconds to about 750 seconds; that is total time, not just GPU time). The inferences from the network (with training steps in between) produce the same values as before, up to the point where the run without the variable used to crash.

Any idea what this could mean?

I think some part of the V8 frontend API may not be thread-safe. Would it be possible to provide a small repro that reproduces the problem?
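
Even something with roughly this shape would be useful (a sketch only, with arbitrary layer and batch sizes): two independent models on the same GPU, each driven from its own thread in a tight loop.

#include <torch/torch.h>
#include <thread>

int main() {
  torch::Device device(torch::kCUDA);
  torch::nn::Conv2d model_a(torch::nn::Conv2dOptions(3, 32, 3).padding(1));
  torch::nn::Conv2d model_b(torch::nn::Conv2dOptions(3, 64, 5).padding(2));
  model_a->to(device);
  model_b->to(device);

  auto run = [&](torch::nn::Conv2d& model) {
    torch::NoGradGuard no_grad;
    for (int i = 0; i < 10000; ++i) {
      auto input = torch::randn({8, 3, 64, 64}, device);
      // Force a sync each iteration so a failure surfaces at the offending call.
      model->forward(input).sum().item<float>();
    }
  };

  std::thread t1(run, std::ref(model_a));
  std::thread t2(run, std::ref(model_b));
  t1.join();
  t2.join();
  return 0;
}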

The performance difference could be because kernel selection in v7 is suboptimal compared to v8. You might want to check whether turning on cuDNN benchmarking helps somewhat (though it won’t help if, e.g., your workload makes heavy use of dynamic shapes): at::globalContext().setBenchmarkCuDNN(true);
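
For example, set it once at startup, before the first forward pass:

#include <torch/torch.h>

int main() {
  // Opt into cuDNN autotuning once, before any inference threads start. It helps
  // most when input shapes are fixed, since each new shape pays a one-time tuning cost.
  at::globalContext().setBenchmarkCuDNN(true);
  // ... construct models, start inference threads ...
  return 0;
}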

Ah I do have the benchmarking turned off for reproducibility, but I can play around with that.
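
For what it’s worth, the relevant startup lines currently look roughly like this (a sketch, not the exact code):

// Reproducibility settings: fixed seed, no autotuning, deterministic cuDNN algorithms.
torch::manual_seed(0);
at::globalContext().setBenchmarkCuDNN(false);
at::globalContext().setDeterministicCuDNN(true);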

I’ll see if I can put together a minimal working example in a repo with few dependencies and will let you know. If I do, should I comment here with a ping or just open an issue on the main PyTorch GitHub?

Both work; if you open a GitHub issue, just CC me (@eqy, same username).

Will do, thanks for your help!


Here is the GitHub issue for tracking purposes: TORCH_CUDNN_V8_API Thread Safety Bug · Issue #103793 · pytorch/pytorch · GitHub

Thanks for the repro! I hope to have some answers by the middle of next week.

Just writing to say that this was a great conversation to follow.