Running TorchScript on multiple GPUs

Hi,
I am trying to run an pytorch object detector using triton server. I used tracing for the model and scripting for the post-processing function. For a single GPU the torchscript runs smoothly on the server.

But in the multi-GPU case I am not able to run the compiled script. The scripted post-processing function memorizes the GPU id (cuda:0) that I used when creating the TorchScript and expects all tensor operations to be performed on that device. This invariably fails when the Triton server passes tensors on any other CUDA device.
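For illustration, here is a minimal sketch (the function name and shapes are hypothetical, not my actual post-processing code) of how a device literal gets frozen into a scripted function:

```python
import torch

@torch.jit.script
def post_process(boxes: torch.Tensor) -> torch.Tensor:
    # Hypothetical example: the device is hardcoded inside the scripted code,
    # so it stays cuda:0 no matter where the input tensors live.
    offsets = torch.zeros(4, device=torch.device("cuda:0"))
    return boxes + offsets

# Works when `boxes` is on cuda:0, but fails with a device-mismatch error
# when the server feeds tensors that live on cuda:1.
```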

Is there any workaround for this?


How are you defining the multi-GPU use case? Could you explain your deployment and where it’s currently breaking?

In the multi-GPU use case I start one Triton server on two GPUs and place one instance of the TorchScript model on each GPU via the Triton config.
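Roughly like this in config.pbtxt (the model name is a placeholder for my actual setup); this asks Triton to create one instance on each listed GPU:

```
name: "detector"
platform: "pytorch_libtorch"
instance_group [
  {
    count: 1
    kind: KIND_GPU
    gpus: [ 0, 1 ]
  }
]
```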

I am using the Triton server. What other information is required?

  • The Triton server is running on cuda:0 and cuda:1
  • The TorchScript was compiled using cuda:0

The model instance on cuda:1 of the server fails with the error message: expected device cuda:1 but got device cuda:0.

Would it work if you write the PyTorch model in a device-agnostic way, i.e. keep using cuda:0 as now and mask the other GPUs with CUDA_VISIBLE_DEVICES?
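As a sketch of the masking idea (assuming each model instance lives in its own process; the file name is illustrative): if a process only sees one physical GPU, that GPU shows up as cuda:0, so the baked-in id keeps matching:

```python
import os

# Hide all GPUs except physical GPU 1; it will appear as cuda:0 in this process.
# Must be set before CUDA is initialized in the process.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch

# The trace that expects cuda:0 now runs on what is physically GPU 1.
scripted = torch.jit.load("model_scripted.pt", map_location="cuda:0")
```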

Since I have to load the same traced model onto both cuda:0 and cuda:1 in the same process, I cannot mask either of the GPUs as you suggest.

The workaround I am using is to keep two redundant traces, one for cuda:0 and one for cuda:1, and load the right file onto the right device. However, this doubles the space required for the model in deployment. And what if a third GPU is added? It is not very scalable.
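Roughly what this workaround looks like (a stand-in module replaces my actual detector, and the file names are illustrative):

```python
import torch

# Stand-in for the actual detector; the real model is traced the same way.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
example = torch.randn(1, 3, 64, 64)

# Export: one redundant trace per physical device, with the device id baked into each file.
for dev in ("cuda:0", "cuda:1"):
    traced = torch.jit.trace(model.to(dev), example.to(dev))
    traced.save(f"model_{dev.replace(':', '_')}.pt")

# Deployment: load the file whose baked-in device matches the target instance.
device = "cuda:1"
scripted = torch.jit.load(f"model_{device.replace(':', '_')}.pt")
```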

It would be nice if torch.jit.load's map_location could accept a dictionary, so I could map from cuda:0 to any CUDA device I want while keeping CPU tensors where they are.
Is there any reason that torch.jit.load's map_location doesn't accept a dictionary the way torch.load does?
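For comparison, a short sketch of the two APIs as I understand them: torch.load accepts a remapping dict, while torch.jit.load only takes a single target device or device string:

```python
import torch

# torch.load: remap storages saved on cuda:0 onto cuda:1, leave everything else alone.
state = torch.load("weights.pt", map_location={"cuda:0": "cuda:1"})

# torch.jit.load: map_location is a single device (or string) applied to all storages;
# there is no per-device mapping, and device literals inside scripted code are untouched.
scripted = torch.jit.load("model_scripted.pt", map_location="cuda:1")
```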

TorchScript is in maintenance mode and not in active development anymore, so I would recommend switching to the torch.compile stack.
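A minimal sketch of that entry point: torch.compile wraps the eager module, which can still be moved to whichever device the instance needs because no device ids are serialized into an artifact:

```python
import torch

# Stand-in module; the real detector would be compiled the same way.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).to("cuda:1")
compiled = torch.compile(model)  # compilation happens lazily on the first call

out = compiled(torch.randn(1, 3, 64, 64, device="cuda:1"))
```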