CUDNN_STATUS_MAPPING_ERROR runtime error when using multiple CUDA contexts

When I create multiple CUDA contexts with the cuda-python library, I get a CUDNN_STATUS_MAPPING_ERROR runtime error. Here is code that reproduces the error:

import torch
import torch.nn as nn
from cuda import cuda

cuda.cuInit(0)  # the driver API must be initialized before any other driver call

# First context: limit execution affinity to 10 SMs (requires MPS)
affinity = cuda.CUexecAffinityParam()
affinity.type = cuda.CUexecAffinityType.CU_EXEC_AFFINITY_TYPE_SM_COUNT
affinity.param.smCount.val = 10
ctx = cuda.cuCtxCreate_v3([affinity], 1, 0, 0)  # returns (CUresult, CUcontext); the new context becomes current
in1 = torch.ones(5000, 30, 10, 5, device="cuda")
conv1 = nn.Conv2d(30, 5, 3, stride=1, padding=1, device="cuda")
mod1 = MyModule("mod1", in1).cuda()  # MyModule and ReusableThread are my own classes (omitted here)
th1 = ReusableThread(mod1, 2)
temp = conv1(in1)  # works fine under the first context

# Second context: 40 SMs
affinity = cuda.CUexecAffinityParam()
affinity.type = cuda.CUexecAffinityType.CU_EXEC_AFFINITY_TYPE_SM_COUNT
affinity.param.smCount.val = 40
ctx = cuda.cuCtxCreate_v3([affinity], 1, 0, 0)
res = cuda.cuCtxSetCurrent(ctx[1])  # ctx is (CUresult, CUcontext), so ctx[1] is the context handle
in2 = torch.ones(5000, 30, 10, 5, device="cuda")
conv2 = nn.Conv2d(30, 5, 3, stride=1, padding=1, device="cuda")
mod2 = MyModule("mod2", in2).cuda()
th2 = ReusableThread(mod2, 68)
temp = conv2(in2)  # CUDNN_STATUS_MAPPING_ERROR is raised here

The error is raised on the last line. Using the raw CUDA driver API I can work with multiple contexts in a single process without any problems, but with PyTorch I get this error. (Note that SM-count execution affinity requires MPS, so you need to start the daemon first with nvidia-cuda-mps-control -d; it is only available on Linux.)
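For reference, every cuda-python driver call above returns a (CUresult, value, ...) tuple, which is why the code indexes ctx[1]. I unwrap those tuples with a small helper like the following sketch (cuda_check is my own name, and it assumes CUDA_SUCCESS converts to integer 0):

```python
def cuda_check(result):
    """Unpack a cuda-python return tuple (err, *values) and raise on failure.

    cuda-python driver calls return a tuple whose first element is the
    CUresult status; CUDA_SUCCESS converts to int 0.
    """
    err, *values = result
    if int(err) != 0:
        raise RuntimeError(f"CUDA driver call failed with error code {int(err)}")
    if len(values) == 1:
        return values[0]  # single payload: return it directly
    return tuple(values)  # zero or several payloads: return a tuple
```

With this helper the context creation becomes ctx = cuda_check(cuda.cuCtxCreate_v3([affinity], 1, 0, 0)), and ctx can then be passed to cuCtxSetCurrent directly instead of indexing into the raw tuple.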