Unable to run a single convolutional layer in different CUDA-contexts

I’m trying to design a real-time scheduler using post-Volta MPS by creating CUDA contexts with different SM counts. In this scheduler, I need to run a single model in a different context on each iteration. This works fine with non-convolutional layers such as a fully-connected layer, but when I change the context and run a convolutional layer, I get a CUDNN_STATUS_MAPPING_ERROR. Below is sample code that reproduces the error. I create two additional CUDA contexts (on top of the default context that already exists), then run a fully-connected layer in all three contexts, which works without problems. When I do the same with the convolutional layer, I get the error.

import torch
import torch.nn as nn
from cuda import cuda

inl = torch.rand(128, 128, device="cuda")
lin = nn.Linear(128, 128, device='cuda')

inc = torch.ones(50, 30, 10, 5, device="cuda").share_memory_()
conv = nn.Conv2d(30, 5, 3, stride=1, padding=1, device='cuda').share_memory()

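def create_context(sm_count):
	# Initialize the driver API before any other driver call (harmless if PyTorch already did)
	cuda.cuInit(0)

	# Request a context limited to sm_count SMs
	affinity = cuda.CUexecAffinityParam()
	affinity.type = cuda.CUexecAffinityType.CU_EXEC_AFFINITY_TYPE_SM_COUNT
	affinity.param.smCount.val = sm_count

	# cuCtxCreate_v3 returns (status, context); the context is created on device 0
	ctx = cuda.cuCtxCreate_v3([affinity], 1, 0, 0)[1]

	return ctx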

# Creating two more contexts
ctx1 = create_context(10)
ctx2 = create_context(40)

# Trying Fully Connected layer
cuda.cuCtxSetCurrent(0) # Sets default context
dummy = lin(inl)

cuda.cuCtxSetCurrent(ctx1)
dummy = lin(inl)

cuda.cuCtxSetCurrent(ctx2)
dummy = lin(inl)

# Trying with Convolutional layer
cuda.cuCtxSetCurrent(0) # Sets default context
dummy = conv(inc)

cuda.cuCtxSetCurrent(ctx1)
dummy = conv(inc)

cuda.cuCtxSetCurrent(ctx2)
dummy = conv(inc)

This is a very simple Python version of what I’m doing. My main project is based on the C++ API, where I get the same error. Here is the full error message I get in the C++ version:

terminate called after throwing an instance of 'c10::CuDNNError'
  what():  cuDNN error: CUDNN_STATUS_MAPPING_ERROR
Exception raised from getCudnnHandle at ../aten/src/ATen/cudnn/Handle.cpp:48 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7ffa619f07d2 in /home/amir/repos/libtorch/lib/libc10.so)
frame #1: at::native::getCudnnHandle() + 0x427 (0x7ff9fe23be57 in /home/amir/repos/libtorch/lib/libtorch_cuda_cpp.so)
frame #2: <unknown function> + 0x26aed47 (0x7ff9fe207d47 in /home/amir/repos/libtorch/lib/libtorch_cuda_cpp.so)
frame #3: <unknown function> + 0x26af144 (0x7ff9fe208144 in /home/amir/repos/libtorch/lib/libtorch_cuda_cpp.so)
frame #4: <unknown function> + 0x26a8c4c (0x7ff9fe201c4c in /home/amir/repos/libtorch/lib/libtorch_cuda_cpp.so)
frame #5: at::native::cudnn_convolution(at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool) + 0x95 (0x7ff9fe202165 in /home/amir/repos/libtorch/lib/libtorch_cuda_cpp.so)
frame #6: <unknown function> + 0x2c734d6 (0x7ff9ae3384d6 in /home/amir/repos/libtorch/lib/libtorch_cuda_cu.so)
frame #7: <unknown function> + 0x2c7354f (0x7ff9ae33854f in /home/amir/repos/libtorch/lib/libtorch_cuda_cu.so)
frame #8: at::_ops::cudnn_convolution::call(at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool) + 0x23d (0x7ff9e5ba5f3d in /home/amir/repos/libtorch/lib/libtorch_cpu.so)
frame #9: at::native::_convolution(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long, bool, bool, bool, bool) + 0xc80 (0x7ff9e5341300 in /home/amir/repos/libtorch/lib/libtorch_cpu.so)
frame #10: <unknown function> + 0x1d69a3a (0x7ff9e5dc9a3a in /home/amir/repos/libtorch/lib/libtorch_cpu.so)
frame #11: at::_ops::_convolution::call(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long, bool, bool, bool, bool) + 0x277 (0x7ff9e58dd557 in /home/amir/repos/libtorch/lib/libtorch_cpu.so)
frame #12: at::native::convolution(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long) + 0xfb (0x7ff9e533a39b in /home/amir/repos/libtorch/lib/libtorch_cpu.so)
frame #13: <unknown function> + 0x1d697da (0x7ff9e5dc97da in /home/amir/repos/libtorch/lib/libtorch_cpu.so)
frame #14: at::_ops::convolution::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long) + 0x176 (0x7ff9e58b41d6 in /home/amir/repos/libtorch/lib/libtorch_cpu.so)
frame #15: <unknown function> + 0x2891f18 (0x7ff9e68f1f18 in /home/amir/repos/libtorch/lib/libtorch_cpu.so)
frame #16: <unknown function> + 0x2892a66 (0x7ff9e68f2a66 in /home/amir/repos/libtorch/lib/libtorch_cpu.so)
frame #17: at::_ops::convolution::call(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long) + 0x251 (0x7ff9e58dc1e1 in /home/amir/repos/libtorch/lib/libtorch_cpu.so)
frame #18: at::native::conv2d(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long) + 0x159 (0x7ff9e533ec49 in /home/amir/repos/libtorch/lib/libtorch_cpu.so)
frame #19: <unknown function> + 0x1dfa022 (0x7ff9e5e5a022 in /home/amir/repos/libtorch/lib/libtorch_cpu.so)
frame #20: at::_ops::conv2d::call(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long) + 0x20e (0x7ff9e5c5f28e in /home/amir/repos/libtorch/lib/libtorch_cpu.so)
frame #21: <unknown function> + 0x385136c (0x7ff9e78b136c in /home/amir/repos/libtorch/lib/libtorch_cpu.so)
frame #22: torch::nn::Conv2dImpl::_conv_forward(at::Tensor const&, at::Tensor const&) + 0x3be (0x7ff9e78ab8ae in /home/amir/repos/libtorch/lib/libtorch_cpu.so)
frame #23: torch::nn::Conv2dImpl::forward(at::Tensor const&) + 0x10 (0x7ff9e78ab9b0 in /home/amir/repos/libtorch/lib/libtorch_cpu.so)
frame #24: dummy(thread_data) + 0x166 (0x5580a54a6d42 in ./build/fgprs)
frame #25: dummy2(void*) + 0x36 (0x5580a54a6f2d in ./build/fgprs)
frame #26: <unknown function> + 0x8609 (0x7ff9e4045609 in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #27: clone + 0x43 (0x7ff9ab024133 in /lib/x86_64-linux-gnu/libc.so.6)

I guess you might be running into an execution error, since cuDNN will use either its heuristics or profiling (if torch.backends.cudnn.benchmark = True is set) to select a working kernel. If this selection is done in one context and the kernel is executed in a different one (with a lower SM count), I would assume it could break. You could try to disable cuDNN via torch.backends.cudnn.enabled = False to avoid this issue.
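For reference, here is a minimal sketch of how the two switches mentioned above could be set before running the layers. The helper name set_cudnn_flags is made up for this example, and the guarded import is only there so the snippet still loads on a machine without PyTorch; it is not a confirmed fix for the error.

```python
try:
    import torch
except ImportError:  # keeps the sketch importable where PyTorch is not installed
    torch = None


def set_cudnn_flags(enabled, benchmark):
    """Set the global cuDNN switches and return the values now in effect.

    set_cudnn_flags is a hypothetical helper name, not a PyTorch API.
    Returns None when PyTorch is not available.
    """
    if torch is None:
        return None
    # enabled=False makes PyTorch fall back to its non-cuDNN convolution paths
    torch.backends.cudnn.enabled = enabled
    # benchmark=True would profile candidate kernels instead of using heuristics
    torch.backends.cudnn.benchmark = benchmark
    return {
        "enabled": torch.backends.cudnn.enabled,
        "benchmark": torch.backends.cudnn.benchmark,
    }


# Disable cuDNN entirely and turn off kernel profiling
flags = set_cudnn_flags(enabled=False, benchmark=False)
```

These are process-wide settings, so they would apply to every context the process uses, not only to the one that is current when they are set.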

It doesn’t work. I tried disabling cuDNN with torch.backends.cudnn.enabled = False, and the only difference is that the runtime error changes to CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`.

Interestingly, it’s not even about using the same module in different contexts. Even if I create a new input tensor and a new conv module in the new context, I still get the same error. Somehow PyTorch does not like having multiple CUDA contexts.

A follow-up failure in cuBLAS would be my second guess. Yes, you might be right, as I’ve never seen this use case before and I’m fairly sure nobody runs tests for it or has ever executed it.
What’s your exact use case to create multiple contexts and execute different workloads there?

It’s part of my PhD thesis. I want to design a real-time scheduler for DL workloads: create a pool of contexts and assign contexts to operations dynamically, based on real-time constraints and deadlines. I’m thinking about migrating to TensorFlow, but I want to make sure that I’m not doing something wrong with LibTorch first. I’ve spent a lot of time getting comfortable with PyTorch’s API.
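To make the idea of a context pool concrete, here is a rough, pure-Python sketch of the assignment step only. All names (PooledContext, Operation, pick_context) are hypothetical, and the cost model, where runtime scales as work divided by SM count, is an assumption made for illustration and not something measured in this thread. No CUDA calls are involved.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class PooledContext:
    name: str          # label for a pre-created CUDA context
    sm_count: int      # SMs the context was created with
    busy: bool = False


@dataclass
class Operation:
    name: str
    work_sm_ms: float   # assumed cost: milliseconds the op would take on one SM
    deadline_ms: float  # time budget for this operation


def pick_context(pool: List[PooledContext], op: Operation) -> Optional[PooledContext]:
    """Return the free context with the fewest SMs that still meets the deadline.

    Assumes runtime = work_sm_ms / sm_count, which real kernels will not follow exactly.
    """
    feasible = [
        ctx for ctx in pool
        if not ctx.busy and op.work_sm_ms / ctx.sm_count <= op.deadline_ms
    ]
    return min(feasible, key=lambda ctx: ctx.sm_count, default=None)


# Same SM counts as the two extra contexts in the reproduction script
pool = [PooledContext("ctx1", 10), PooledContext("ctx2", 40)]
tight = Operation("conv", work_sm_ms=200.0, deadline_ms=10.0)
loose = Operation("linear", work_sm_ms=50.0, deadline_ms=10.0)

tight_choice = pick_context(pool, tight)  # needs the 40-SM context
loose_choice = pick_context(pool, loose)  # fits in the 10-SM context
```

Picking the smallest context that still meets the deadline leaves the larger contexts free for operations with tighter budgets. In the real scheduler the chosen context would then be made current before the operation is launched, which is exactly the step that fails in this thread.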

Do you have any idea why this is happening? This is a huge obstacle for my research right now. I considered other frameworks like TF or MXNet, but none of them has a C++ API as mature as PyTorch’s.

No. As mentioned before, I would guess that contexts with different properties are not supported if you are trying to reuse the same algorithms. Since it’s part of your research, I would recommend scaling down the problem first and trying to run a simple workload in pure cuBLAS/cuDNN using different contexts, to see whether these limitations are expected.