Using Pycuda and Pytorch Together

I’m having an issue using pycuda (for TensorRT) and pytorch together. When I move a “random_tensor” to the GPU, the script below fails; if “random_tensor” is left on the CPU, the script completes without error. Could this be the CUDA contexts of pycuda and pytorch clashing?

I can include more code if necessary.

import tensorrt as trt
import torch
import pycuda.driver as cudadriver
import pycuda.autoinit as cudacontext

random_tensor = torch.ones(1)
sample_tensor = torch.randint(0, 255, size=(3,32,32))
engine = UnifiedTensorRTEngine(
    tensorrt_file="./dummy.trt",
    output_shape=(1, 10),
)
use_gpu = True

if use_gpu:
    random_tensor = random_tensor.cuda()          #FAILS
else:
    pass                                          #PASSES

print(engine(sample_tensor))

Here’s the error that results when I move the tensor to the GPU:

[TensorRT] ERROR: ../rtExt/cuda/cudaFusedConvActRunner.cpp (313) - Cuda Error in executeFused: 400 (invalid resource handle)
[TensorRT] ERROR: FAILED_EXECUTION: std::exception

And the output is all zeros:
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

Here’s the source code for pycuda.autoinit that sets up the default cuda context:

from __future__ import absolute_import
import pycuda.driver as cuda

# Initialize CUDA
cuda.init()

from pycuda.tools import make_default_context
global context
context = make_default_context()
device = context.get_device()

def _finish_up():
    global context
    context.pop()
    context = None

    from pycuda.tools import clear_context_caches
    clear_context_caches()

import atexit
atexit.register(_finish_up)
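
If the problem really is two separate contexts, one thing I’m considering trying (a rough sketch, untested with this setup) is to skip pycuda.autoinit entirely and instead retain the device’s primary context, which is the context the CUDA runtime, and therefore PyTorch, uses:

import pycuda.driver as cuda

cuda.init()
# Instead of make_default_context(), which creates a brand-new context,
# retain the device's primary context so pycuda and PyTorch share one.
# Sketch only; not verified against this TensorRT engine.
context = cuda.Device(0).retain_primary_context()
context.push()

# ... build the engine, allocate buffers, run inference ...

context.pop()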

My Environment:

Package Versions:
            numpy==1.19.0
            Pillow==5.4.1
            torch==1.5.0
            torchvision==0.6.0
            onnx==1.7.0
            onnxruntime==1.2.0
            pycuda==2019.1.2
            tensorrt==7.0.0.11

I’m not sure where you are using pycuda in your code.
It seems you are trying to pass a PyTorch tensor to TRT, but also include pycuda for some reason.
Could you explain your use case a bit, please?

Sure, here’s a more expansive example.

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

So I trained this simple CIFAR model for a single epoch. Using the .pth file and the model definition, I generated the ONNX file and then the TensorRT engine file. I would attach the .trt file, but it’s not an authorized extension.
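
The conversion was roughly along these lines (a sketch from memory; "cifar_net.pth" and "dummy.onnx" are just illustrative names for the files I used):

import torch

# Net is the class defined above; load the trained weights and export to ONNX.
model = Net()
model.load_state_dict(torch.load("cifar_net.pth", map_location="cpu"))
model.eval()

dummy_input = torch.randn(1, 3, 32, 32)
torch.onnx.export(model, dummy_input, "dummy.onnx", opset_version=11)

# Then, from the shell (TensorRT 7):
#   trtexec --onnx=dummy.onnx --saveEngine=dummy.trt

The engine wrapper and the script that reproduces the error: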

from abc import ABC, abstractmethod
import torch
import numpy as np
import tensorrt as trt
import pycuda.driver as cudadriver
import pycuda.autoinit as cudacontext


class TensorRTEngine(ABC):
    def __init__(
        self,
        tensorrt_file: str,
        output_shape: tuple,
        multiple_returns: bool = False,
        return_slices: list = [],
    ) -> None:
        self.engine = self.load_tensorrt_engine(tensorrt_file)
        self.output_shape = output_shape
        self.multiple_returns = multiple_returns
        self.return_slices = return_slices
        self.execution_context = self.engine.create_execution_context()
        self.execution_stream = cudadriver.Stream()

    def load_tensorrt_engine(self, path):
        trt_logger = trt.Logger()
        with open(path, "rb") as f, trt.Runtime(trt_logger) as runtime:
            return runtime.deserialize_cuda_engine(f.read())

    def pre_process(self, image):
        if isinstance(image, torch.Tensor):
            image = image.cpu().numpy()
        image = image.astype(np.float32)
        return image

    def post_process(self, output_tensors):
        rstacked = output_tensors.view(self.output_shape)
        if self.multiple_returns:
            return [rstacked[s] for s in self.return_slices]
        else:
            return rstacked

    @abstractmethod
    def allocate_memory_buffers(self):
        pass


class UnifiedTensorRTEngine(TensorRTEngine):
    def __init__(
        self,
        tensorrt_file: str,
        output_shape: tuple,
        multiple_returns: bool = False,
        return_slices: list = [],
    ):
        super().__init__(tensorrt_file, output_shape, multiple_returns, return_slices)
        self.in_memory, self.out_memory, self.bindings = self.allocate_memory_buffers()

    def __call__(self, image):
        image = self.pre_process(image)
        self.in_memory[:, :, :] = image
        self.execution_context.execute_async_v2(
            bindings=self.bindings, stream_handle=self.execution_stream.handle
        )
        self.execution_stream.synchronize()
        stacked = torch.stack([torch.Tensor(out) for out in self.out_memory])
        post_processed_output = self.post_process(stacked)
        return post_processed_output

    def allocate_memory_buffers(self):
        dtypes, bindings = [], []
        for binding in self.engine:
            dtypes.append(trt.nptype(self.engine.get_binding_dtype(binding)))
        in_shape = tuple(self.engine.get_binding_shape(0))
        out_shape = tuple(self.engine.get_binding_shape(1))
        input_data_size = trt.volume(in_shape) * self.engine.max_batch_size
        output_data_size = trt.volume(out_shape) * self.engine.max_batch_size
        in_memory = cudadriver.managed_empty(
            shape=in_shape,
            dtype=dtypes[0],
            mem_flags=cudadriver.mem_attach_flags.GLOBAL,
        )
        out_memory = cudadriver.managed_empty(
            shape=out_shape,
            dtype=dtypes[1],
            mem_flags=cudadriver.mem_attach_flags.GLOBAL,
        )
        self.execution_stream.synchronize()
        bindings.append(int(in_memory.base.get_device_pointer()))
        bindings.append(int(out_memory.base.get_device_pointer()))
        return in_memory, out_memory, bindings


if __name__ == "__main__":
    engine = UnifiedTensorRTEngine(tensorrt_file="dummy.trt", output_shape=(1, -1))
    image_tensor = torch.randint(0, 255, size=(3, 32, 32))
    use_gpu = True
    if use_gpu:
        image_tensor = image_tensor.cuda()  # FAILS
    else:
        pass                                # PASSES
    out = engine(image_tensor)
    print(out)

Here’s the error that occurs:

[TensorRT] ERROR: ../rtExt/cuda/cudaFusedConvActRunner.cpp (313) - Cuda Error in executeFused: 400 (invalid resource handle)
[TensorRT] ERROR: FAILED_EXECUTION: std::exception
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

After inference there are some further post-processing steps that use pytorch, so being able to use tensorrt and pytorch in the same process is important.
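
For example, the post-processing is along these lines (a sketch; the engine’s post_process already returns a torch.Tensor, so regular PyTorch ops apply directly):

import torch

# logits is the (1, 10) tensor returned by the engine above.
logits = engine(image_tensor)
probs = torch.softmax(logits, dim=1)  # class probabilities
pred = probs.argmax(dim=1)            # predicted CIFAR class index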


Hi, have you solved the problem?

Yes. Torch shouldn’t be necessary in this case. Cast the torch tensor to a numpy array and pass that to the engine.
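
Something along these lines works (a sketch, with engine being the UnifiedTensorRTEngine instance from above):

import numpy as np
import torch

image_tensor = torch.randint(0, 255, size=(3, 32, 32))
# Stay on the CPU: convert to a NumPy array and pass that to the engine,
# instead of calling .cuda() on the tensor.
image_array = image_tensor.numpy().astype(np.float32)
out = engine(image_array)
print(out)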

In case anybody needs it, I’ve written a small library to transfer torch tensors to pycuda for in-place manipulation. Please check out cutex. I use it in my personal projects for processing images and implementing novel neural-network layers (through autograd.Function).