How can I run a PyTorch model with ONNX Runtime on CUDA?

I have converted the model to ONNX and it works on CPU, but I am not able to pass a CUDA array through it. Any suggestions? Thanks

```python
import numpy as np
import cupy as cp
import onnxruntime as ort

# Random NCHW input; cast to float32, which exported PyTorch models usually expect
x_cpu = np.random.rand(1, 3, 256, 192).astype(np.float32)
x_gpu = cp.asarray(x_cpu)  # the same data as a CuPy (GPU) array
x_gpu.shape

def to_numpy(tensor):
    # Convert a PyTorch tensor to a NumPy array on the host
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

model_path = '/content/transpose_cuda.onnx'

providers = [
    ('CUDAExecutionProvider', {
        'device_id': 0,
        'arena_extend_strategy': 'kNextPowerOfTwo',
        'gpu_mem_limit': 2 * 1024 * 1024 * 1024,
        'cudnn_conv_algo_search': 'EXHAUSTIVE',
        'do_copy_in_default_stream': True,
    }),
    'CPUExecutionProvider',
]
# Use the provider list defined above (CUDA with options, CPU as fallback)
ort_session = ort.InferenceSession(model_path, providers=providers)

# compute ONNX Runtime output prediction
# Feeding the CuPy array x_gpu is what raises the RuntimeError below;
# feeding to_numpy(input_tensor) (a NumPy array) works
ort_inputs = {ort_session.get_inputs()[0].name: x_gpu}
ort_outs = ort_session.run(None, ort_inputs)

# Comparing output tolerance from pytorch model versus onnx
# np.testing.assert_allclose(pytorch_out, ort_outs[0], rtol=1e-03, atol=1e-05)
ort_outs[0].shape
```


```
RuntimeError                              Traceback (most recent call last)

in ()
     18 # compute ONNX Runtime output prediction
     19 ort_inputs = {ort_session.get_inputs()[0].name: x_gpu}#to_numpy(input_tensor)}
---> 20 ort_outs = ort_session.run(None, ort_inputs)
     21
     22 #Comparing output tolerance from pytorch model versus onnx

/usr/local/lib/python3.7/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py in run(self, output_names, input_feed, run_options)
    190 output_names = [output.name for output in self._outputs_meta]
    191 try:
--> 192 return self._sess.run(output_names, input_feed, run_options)
    193 except C.EPFail as err:
    194 if self._enable_fallback:

RuntimeError: Input must be a list of dictionaries or a single numpy array for input 'input'.
```

I think even with the GPU execution provider, the standard thing is to pass in NumPy (i.e. CPU) arrays. (Use `assert 'CUDAExecutionProvider' in onnxruntime.get_available_providers()` or `nvidia-smi` to check that you are using the GPU.)

Best regards

Thomas

Hey Tom, I am using the GPU. I checked with:

```python
import onnxruntime as ort
ort.get_device()
```

I referred to this page: